diff --git a/pytype/abstract/_base.py b/pytype/abstract/_base.py index 821d539ff..30a448268 100644 --- a/pytype/abstract/_base.py +++ b/pytype/abstract/_base.py @@ -4,13 +4,16 @@ this module; use the alias in abstract.py instead. """ -from typing import Any +from typing import TypeVar, Any from pytype import utils from pytype.abstract import abstract_utils from pytype.pytd import mro from pytype.types import types +_T0 = TypeVar("_T0") +_TBaseValue = TypeVar("_TBaseValue", bound="BaseValue") + _isinstance = abstract_utils._isinstance # pylint: disable=protected-access _make = abstract_utils._make # pylint: disable=protected-access @@ -29,7 +32,7 @@ class BaseValue(utils.ContextWeakrefMixin, types.BaseValue): formal = False # is this type non-instantiable? - def __init__(self, name, ctx): + def __init__(self, name, ctx) -> None: """Basic initializer for all BaseValues.""" super().__init__(ctx) # This default cls value is used by things like Unsolvable that inherit @@ -94,15 +97,15 @@ def full_name(self): def __repr__(self): return self.name - def compute_mro(self): + def compute_mro(self) -> tuple: # default for objects with no MRO return () - def default_mro(self): + def default_mro(self) -> tuple[types.BaseValue, Any]: # default for objects with unknown MRO return (self, self.ctx.convert.object_type) - def get_default_fullhash(self): + def get_default_fullhash(self) -> int: return id(self) def get_fullhash(self, seen=None): @@ -143,7 +146,7 @@ def get_formal_type_parameter(self, t): del t return self.ctx.convert.unsolvable - def property_get(self, callself, is_class=False): + def property_get(self: _TBaseValue, callself, is_class=False) -> _TBaseValue: """Bind this value to the given self or cls. This function is similar to __get__ except at the abstract level. 
This does @@ -181,7 +184,7 @@ def get_special_attribute(self, unused_node, name, unused_valself): return self.cls.to_variable(self.ctx.root_node) return None - def get_own_new(self, node, value): + def get_own_new(self, node: _T0, value) -> tuple[_T0, None]: """Get this value's __new__ method, if it isn't object.__new__.""" del value # Unused, only classes have methods. return node, None @@ -214,7 +217,7 @@ def argcount(self, node): """Returns the minimum number of arguments needed for a call.""" raise NotImplementedError(self.__class__.__name__) - def register_instance(self, instance): # pylint: disable=unused-arg + def register_instance(self, instance) -> None: # pylint: disable=unused-arg """Treating self as a class definition, register an instance of it. This is used for keeping merging call records on instances when generating @@ -240,7 +243,7 @@ def to_pytd_def(self, node, name): """Get a PyTD definition for this object.""" return self.ctx.pytd_convert.value_to_pytd_def(node, self, name) - def get_default_type_key(self): + def get_default_type_key(self) -> type[types.BaseValue]: """Gets a default type key. See get_type_key.""" return type(self) @@ -304,15 +307,15 @@ def to_binding(self, node): (binding,) = self.to_variable(node).bindings return binding - def has_varargs(self): + def has_varargs(self) -> bool: """Return True if this is a function and has a *args parameter.""" return False - def has_kwargs(self): + def has_kwargs(self) -> bool: """Return True if this is a function and has a **kwargs parameter.""" return False - def _unique_parameters(self): + def _unique_parameters(self) -> list[None]: """Get unique parameter subtypes as variables. 
This will retrieve 'children' of this value that contribute to the @@ -339,7 +342,7 @@ def _get_values(parameter): return [_get_values(parameter) for parameter in self._unique_parameters()] - def init_subclass(self, node, cls): + def init_subclass(self, node: _T0, cls) -> _T0: """Allow metaprogramming via __init_subclass__. We do not analyse __init_subclass__ methods in the code, but overlays that @@ -363,13 +366,13 @@ def init_subclass(self, node, cls): del cls return node - def update_official_name(self, _): + def update_official_name(self, _) -> None: """Update the official name.""" - def is_late_annotation(self): + def is_late_annotation(self) -> bool: return False - def should_set_self_annot(self): + def should_set_self_annot(self) -> bool: # To do argument matching for custom generic classes, the 'self' annotation # needs to be replaced with a generic type. diff --git a/pytype/abstract/_classes.py b/pytype/abstract/_classes.py index cd0b2a3d7..03180646f 100644 --- a/pytype/abstract/_classes.py +++ b/pytype/abstract/_classes.py @@ -1,7 +1,8 @@ """Abstract class representations.""" +from collections.abc import Generator import logging -from typing import Any +from typing import Any, TypeVar from pytype import datatypes from pytype.abstract import _base @@ -21,12 +22,17 @@ from pytype.typegraph import cfg from pytype.types import types -log = logging.getLogger(__name__) +_T0 = TypeVar("_T0") +_TParameterizedClass = TypeVar( + "_TParameterizedClass", bound="ParameterizedClass" +) + +log: logging.Logger = logging.getLogger(__name__) _isinstance = abstract_utils._isinstance # pylint: disable=protected-access # These classes can't be imported due to circular deps. 
-_ContextType = Any # context.Context -_TypeParamType = Any # typing.TypeParameter +_ContextType: Any = Any # context.Context +_TypeParamType: Any = Any # typing.TypeParameter class BuildClass(_base.BaseValue): @@ -34,11 +40,11 @@ class BuildClass(_base.BaseValue): CLOSURE_NAME = "__class__" - def __init__(self, ctx): + def __init__(self, ctx) -> None: super().__init__("__build_class__", ctx) self.decorators = [] - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple[Any, Any]: args = args.simplify(node, self.ctx) funcvar, name = args.posargs[0:2] kwargs = args.namedargs @@ -141,7 +147,7 @@ def __init__( self._override_check() self._first_opcode = first_opcode - def _get_class(self): + def _get_class(self) -> "ParameterizedClass": return ParameterizedClass( self.ctx.convert.type_type, {abstract_utils.T: self}, self.ctx ) @@ -149,7 +155,7 @@ def _get_class(self): def get_first_opcode(self): return self._first_opcode - def update_method_type_params(self): + def update_method_type_params(self) -> None: # For function type parameters check methods = [] # members of self._undecorated_methods that will be ignored for updating @@ -171,7 +177,7 @@ def update_method_type_params(self): for m in methods: m.update_signature_scope(self) - def _type_param_check(self): + def _type_param_check(self) -> None: """Throw exception for invalid type parameters.""" self.update_method_type_params() if self.template: @@ -197,7 +203,7 @@ def _type_param_check(self): self, f"Conflicting value for TypeVar {t.full_name}" ) - def _override_check(self): + def _override_check(self) -> None: """Checks for @typing.override errors.""" for name, member in self.members.items(): member_data = [ @@ -239,7 +245,7 @@ def _get_defining_base_class(self, attr): return base return None - def collect_inner_cls_types(self, max_depth=5): + def collect_inner_cls_types(self, max_depth=5) -> set: """Collect all the type parameters from nested classes.""" 
templates = set() if max_depth > 0: @@ -254,7 +260,7 @@ def collect_inner_cls_types(self, max_depth=5): templates.update(mbr.collect_inner_cls_types(max_depth - 1)) return templates - def get_inner_classes(self): + def get_inner_classes(self) -> list: """Return the list of top-level nested classes.""" inner_classes = [] for member in self.members.values(): @@ -272,7 +278,7 @@ def get_inner_classes(self): inner_classes.append(value) return inner_classes - def get_own_attributes(self): + def get_own_attributes(self) -> set[str]: attributes = set(self.members) annotations_dict = abstract_utils.get_annotations_dict(self.members) if annotations_dict: @@ -285,13 +291,13 @@ def _can_be_abstract(var): return {name for name, var in self.members.items() if _can_be_abstract(var)} - def register_instance(self, instance): + def register_instance(self, instance) -> None: self.instances.add(instance) - def register_canonical_instance(self, instance): + def register_canonical_instance(self, instance) -> None: self.canonical_instances.add(instance) - def bases(self): + def bases(self) -> list[cfg.Variable]: return self._bases def metaclass(self, node): @@ -312,7 +318,7 @@ def instantiate(self, node, container=None): # the frame is a SimpleFrame with no opcode. return super().instantiate(node, container) - def __repr__(self): + def __repr__(self) -> str: return f"InterpreterClass({self.name})" def __contains__(self, name): @@ -321,7 +327,7 @@ def __contains__(self, name): annotations_dict = abstract_utils.get_annotations_dict(self.members) return annotations_dict and name in annotations_dict.annotated_locals - def has_protocol_base(self): + def has_protocol_base(self) -> bool: for base_var in self._bases: for base in base_var.data: if isinstance(base, PyTDClass) and base.full_name == "typing.Protocol": @@ -350,7 +356,7 @@ class PyTDClass( mro: Method resolution order. An iterable of BaseValue. 
""" - def __init__(self, name, pytd_cls, ctx): + def __init__(self, name, pytd_cls, ctx) -> None: # Apply decorators first, in case they set any properties that later # initialization code needs to read. self.has_explicit_init = any(x.name == "__init__" for x in pytd_cls.methods) @@ -424,7 +430,7 @@ def make(cls, name, pytd_cls, ctx): # If none of the special classes have matched, return the PyTDClass return c - def _populate_decorator_metadata(self): + def _populate_decorator_metadata(self) -> None: """Fill in class attribute metadata for decorators like @dataclass.""" keyed_decorators = {} for decorator in self.decorators: @@ -454,7 +460,7 @@ def _populate_decorator_metadata(self): self._init_attr_metadata_from_pytd(decorator) self._recompute_init_from_metadata(key) - def _init_attr_metadata_from_pytd(self, decorator): + def _init_attr_metadata_from_pytd(self, decorator) -> None: """Initialise metadata[key] with a list of Attributes.""" # Use the __init__ function as the source of truth for dataclass fields; if # this is a generated module we will have already processed ClassVar and @@ -485,7 +491,7 @@ def _init_attr_metadata_from_pytd(self, decorator): ] self.compute_attr_metadata(own_attrs, decorator) - def _recompute_init_from_metadata(self, key): + def _recompute_init_from_metadata(self, key) -> None: # Some decorated classes (dataclasses e.g.) have their __init__ function # set via traversing the MRO to collect initializers from decorated parent # classes as well. 
Since we don't have access to the MRO when initially @@ -510,7 +516,7 @@ def get_own_abstract_methods(self): if isinstance(member, pytd.Function) and member.is_abstract } - def bases(self): + def bases(self) -> list: convert = self.ctx.convert converted_bases = [] for base in self.pytd_cls.bases: @@ -561,7 +567,9 @@ def _convert_member(self, name, member, subst=None): else: raise AssertionError(f"Invalid class member {pytd_utils.Print(member)}") - def _new_instance(self, container, node, args): + def _new_instance( + self, container, node, args + ) -> _instance_base.Instance | tuple: if self.full_name == "builtins.tuple" and args.is_empty(): value = _instances.Tuple((), self.ctx) else: @@ -577,10 +585,10 @@ def _new_instance(self, container, node, args): def instantiate(self, node, container=None): return self.ctx.convert.pytd_cls_to_instance_var(self.pytd_cls, {}, node) - def __repr__(self): + def __repr__(self) -> str: return f"PyTDClass({self.name})" - def __contains__(self, name): + def __contains__(self, name) -> bool: return name in self._member_map def convert_as_instance_attribute(self, name, instance): @@ -624,7 +632,7 @@ def convert_as_instance_attribute(self, name, instance): subst[name] = self.ctx.new_unsolvable(self.ctx.root_node) return self._convert_member(name, c, subst) - def has_protocol_base(self): + def has_protocol_base(self) -> bool: for base in self.pytd_cls.bases: if base.name == "typing.Protocol": return True @@ -640,7 +648,7 @@ class FunctionPyTDClass(PyTDClass): save the value of `func`, not just its type of Callable. 
""" - def __init__(self, func, ctx): + def __init__(self, func, ctx) -> None: super().__init__("typing.Callable", ctx.convert.function_type.pytd_cls, ctx) self.func = func @@ -694,12 +702,12 @@ def __init__( mixin.NestedAnnotation.init_mixin(self) self._type_param_check() - def __repr__(self): + def __repr__(self) -> str: return "ParameterizedClass(cls={!r} params={})".format( self.base_cls, self._formal_type_parameters ) - def _type_param_check(self): + def _type_param_check(self) -> None: """Throw exception for invalid type parameters.""" # It will cause infinite recursion if `formal_type_parameters` is # `LazyFormalTypeParameters` @@ -715,14 +723,14 @@ def get_formal_type_parameters(self): for k, v in self.formal_type_parameters.items() } - def __eq__(self, other): + def __eq__(self, other) -> bool: if isinstance(other, type(self)): return self.base_cls == other.base_cls and ( self.formal_type_parameters == other.formal_type_parameters ) return NotImplemented - def __ne__(self, other): + def __ne__(self, other) -> bool: return not self == other def __hash__(self): @@ -750,10 +758,12 @@ def __hash__(self): hashval = self._hash return hashval - def __contains__(self, name): + def __contains__(self, name) -> bool: return name in self.base_cls - def _raw_formal_type_parameters(self): + def _raw_formal_type_parameters( + self, + ) -> Generator[tuple[Any, Any], Any, None]: assert isinstance( self._formal_type_parameters, abstract_utils.LazyFormalTypeParameters ) @@ -777,7 +787,7 @@ def formal_type_parameters(self) -> dict[str | int, _base.BaseValue]: self._load_formal_type_parameters() return self._formal_type_parameters # pytype: disable=bad-return-type - def _load_formal_type_parameters(self): + def _load_formal_type_parameters(self) -> None: if self._formal_type_parameters_loaded: return if isinstance( @@ -802,7 +812,7 @@ def _load_formal_type_parameters(self): ) self._formal_type_parameters_loaded = True - def compute_mro(self): + def compute_mro(self) -> tuple: 
return (self,) + self.base_cls.mro[1:] def instantiate(self, node, container=None): @@ -831,7 +841,7 @@ def cls(self): def cls(self, cls): self._cls = cls - def set_class(self, node, var): + def set_class(self, node, var) -> None: self.base_cls.set_class(node, var) @property @@ -842,7 +852,7 @@ def official_name(self): def official_name(self, official_name): self.base_cls.official_name = official_name - def _is_callable(self): + def _is_callable(self) -> bool: if not isinstance(self.base_cls, (InterpreterClass, PyTDClass)): # We don't know how to instantiate this base_cls. return False @@ -860,7 +870,7 @@ def _is_callable(self): # the side of allowing such calls. return not self.is_abstract - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple[Any, Any]: if not self._is_callable(): raise error_types.NotCallable(self) else: @@ -872,10 +882,10 @@ def get_formal_type_parameter(self, t): def get_inner_types(self): return self.formal_type_parameters.items() - def update_inner_type(self, key, typ): + def update_inner_type(self, key, typ) -> None: self.formal_type_parameters[key] = typ - def replace(self, inner_types): + def replace(self: _TParameterizedClass, inner_types) -> _TParameterizedClass: inner_types = dict(inner_types) if isinstance(self, LiteralClass): if inner_types == self.formal_type_parameters: @@ -904,17 +914,19 @@ class CallableClass(ParameterizedClass, mixin.HasSlots): # pytype: disable=sign When there are no args (CallableClass[[], ...]), ARGS contains abstract.Empty. """ - def __init__(self, base_cls, formal_type_parameters, ctx, template=None): + def __init__( + self, base_cls, formal_type_parameters, ctx, template=None + ) -> None: super().__init__(base_cls, formal_type_parameters, ctx, template) mixin.HasSlots.init_mixin(self) self.set_native_slot("__call__", self.call_slot) # We subtract two to account for "ARGS" and "RET". 
self.num_args = len(self.formal_type_parameters) - 2 - def __repr__(self): + def __repr__(self) -> str: return f"CallableClass({self.formal_type_parameters})" - def get_formal_type_parameters(self): + def get_formal_type_parameters(self) -> dict[Any, _base.BaseValue]: return { abstract_utils.full_type_name( self, abstract_utils.ARGS @@ -924,7 +936,7 @@ def get_formal_type_parameters(self): ): self.formal_type_parameters[abstract_utils.RET], } - def call_slot(self, node, *args, **kwargs): + def call_slot(self, node: _T0, *args, **kwargs) -> tuple[_T0, Any]: """Implementation of CallableClass.__call__.""" if kwargs: raise error_types.WrongKeywordArgs( @@ -983,7 +995,7 @@ def get_args(self): """Get the callable's posargs as a list.""" return [self.formal_type_parameters[i] for i in range(self.num_args)] - def has_paramspec(self): + def has_paramspec(self) -> bool: return _isinstance( self.formal_type_parameters[abstract_utils.ARGS], ("ParamSpec", "Concatenate"), @@ -993,13 +1005,13 @@ def has_paramspec(self): class LiteralClass(ParameterizedClass): """The class of a typing.Literal.""" - def __init__(self, instance, ctx, template=None): + def __init__(self, instance, ctx, template=None) -> None: base_cls = ctx.convert.lookup_value("typing", "Literal") formal_type_parameters = {abstract_utils.T: instance.cls} super().__init__(base_cls, formal_type_parameters, ctx, template) self._instance = instance - def __repr__(self): + def __repr__(self) -> str: return f"LiteralClass({self._instance})" def __eq__(self, other): @@ -1012,7 +1024,7 @@ def __eq__(self, other): return self.value == other.value return super().__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return hash((super().__hash__(), self._instance)) @property @@ -1035,7 +1047,9 @@ class TupleClass(ParameterizedClass, mixin.HasSlots): # pytype: disable=signatu do for Tuple, since we can't evaluate type parameters during initialization. 
""" - def __init__(self, base_cls, formal_type_parameters, ctx, template=None): + def __init__( + self, base_cls, formal_type_parameters, ctx, template=None + ) -> None: super().__init__(base_cls, formal_type_parameters, ctx, template) mixin.HasSlots.init_mixin(self) self.set_native_slot("__getitem__", self.getitem_slot) @@ -1052,15 +1066,15 @@ def __init__(self, base_cls, formal_type_parameters, ctx, template=None): self._instance_cache = {} self.slots = () # tuples don't have any writable attributes - def __repr__(self): + def __repr__(self) -> str: return f"TupleClass({self.formal_type_parameters})" - def compute_mro(self): + def compute_mro(self) -> tuple: # ParameterizedClass removes the base PyTDClass(tuple) from the mro; add it # back here so that isinstance(tuple) checks work. return (self,) + self.base_cls.mro - def get_formal_type_parameters(self): + def get_formal_type_parameters(self) -> dict[Any, _base.BaseValue]: return { abstract_utils.full_type_name( self, abstract_utils.T @@ -1101,7 +1115,7 @@ def _instantiate_index(self, node, index): index %= self.tuple_length # fixes negative indices return self.formal_type_parameters[index].instantiate(node) - def register_instance(self, instance): + def register_instance(self, instance) -> None: # A TupleClass can never have more than one registered instance because the # only direct instances of TupleClass are Tuple objects, which create their # own class upon instantiation. 
We store the instance in order to track @@ -1109,7 +1123,7 @@ def register_instance(self, instance): assert not self._instance self._instance = instance - def getitem_slot(self, node, index_var): + def getitem_slot(self, node: _T0, index_var) -> tuple[Any, Any]: """Implementation of tuple.__getitem__.""" try: index = self.ctx.convert.value_to_constant( @@ -1155,7 +1169,7 @@ def get_special_attribute(self, node, name, valself): return mixin.HasSlots.get_special_attribute(self, node, name, valself) return super().get_special_attribute(node, name, valself) - def add_slot(self, node, other_var): + def add_slot(self, node: _T0, other_var) -> tuple[Any, Any]: """Implementation of tuple.__add__.""" try: other = abstract_utils.get_atomic_value(other_var) diff --git a/pytype/abstract/_function_base.py b/pytype/abstract/_function_base.py index b47dc497e..309c26467 100644 --- a/pytype/abstract/_function_base.py +++ b/pytype/abstract/_function_base.py @@ -1,9 +1,11 @@ """Base abstract representations of functions.""" +from collections.abc import Callable, Generator import contextlib import inspect import itertools import logging +from typing import Any from pytype.abstract import _base from pytype.abstract import _classes @@ -16,7 +18,7 @@ from pytype.errors import error_types from pytype.types import types -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) _isinstance = abstract_utils._isinstance # pylint: disable=protected-access @@ -30,7 +32,7 @@ class Function(_instance_base.SimpleValue, types.Function): bound_class: type["BoundFunction"] - def __init__(self, name, ctx): + def __init__(self, name, ctx) -> None: super().__init__(name, ctx) self.cls = _classes.FunctionPyTDClass(self, ctx) self.is_attribute_of_class = False @@ -43,7 +45,9 @@ def __init__(self, name, ctx): self.ctx.root_node, name ) - def property_get(self, callself, is_class=False): + def property_get( + self, callself, is_class=False + ) -> "BoundFunction|Function": if 
self.name == "__new__" or not callself or is_class: return self self.is_attribute_of_class = True @@ -122,7 +126,7 @@ def _extract_defaults(self, defaults_var): def set_function_defaults(self, node, defaults_var): raise NotImplementedError(self.__class__.__name__) - def update_signature_scope(self, cls): + def update_signature_scope(self, cls) -> None: return @@ -135,7 +139,7 @@ class NativeFunction(Function): ctx: context.Context instance. """ - def __init__(self, name, func, ctx): + def __init__(self, name, func, ctx) -> None: super().__init__(name, ctx) self.func = func self.bound_class = lambda callself, underlying: self @@ -203,7 +207,7 @@ def call(self, node, func, args, alias_map=None): raise error_types.DuplicateKeyword(sig, args, self.ctx, "self") return self.func(node, *posargs, **namedargs) - def get_positional_names(self): + def get_positional_names(self) -> list: code = self.func.func_code return list(code.varnames[: code.argcount]) @@ -214,7 +218,7 @@ def property_get(self, callself, is_class=False): class BoundFunction(_base.BaseValue): """An function type which has had an argument bound into it.""" - def __init__(self, callself, underlying): + def __init__(self, callself, underlying) -> None: super().__init__(underlying.name, underlying.ctx) self.cls = _classes.FunctionPyTDClass(self, self.ctx) self._callself = callself @@ -262,7 +266,7 @@ def signature(self): def callself(self): return self._callself - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple[Any, Any]: if self.name.endswith(".__init__"): self.ctx.callself_stack.append(self._callself) # The "self" parameter is automatically added to the list of arguments, but @@ -360,7 +364,7 @@ class BoundInterpreterFunction(BoundFunction): """The method flavor of InterpreterFunction.""" @contextlib.contextmanager - def record_calls(self): + def record_calls(self) -> contextlib._GeneratorContextManager: with self.underlying.record_calls(): yield 
@@ -383,7 +387,7 @@ def is_overload(self, value): def defaults(self): return self.underlying.defaults - def iter_signature_functions(self): + def iter_signature_functions(self) -> Generator[Any, Any, None]: for f in self.underlying.iter_signature_functions(): yield self.underlying.bound_class(self._callself, f) @@ -398,7 +402,7 @@ class BoundPyTDFunction(BoundFunction): class ClassMethod(_base.BaseValue): """Implements @classmethod methods in pyi.""" - def __init__(self, name, method, callself, ctx): + def __init__(self, name, method, callself, ctx) -> None: super().__init__(name, ctx) self.cls = self.ctx.convert.function_type self.method = method @@ -412,14 +416,14 @@ def call(self, node, func, args, alias_map=None): node, func, args.replace(posargs=(self._callcls,) + args.posargs) ) - def to_bound_function(self): + def to_bound_function(self) -> BoundPyTDFunction: return BoundPyTDFunction(self._callcls, self.method) class StaticMethod(_base.BaseValue): """Implements @staticmethod methods in pyi.""" - def __init__(self, name, method, _, ctx): + def __init__(self, name, method, _, ctx) -> None: super().__init__(name, ctx) self.cls = self.ctx.convert.function_type self.method = method @@ -436,7 +440,7 @@ class Property(_base.BaseValue): resolved as a function, not as a constant. """ - def __init__(self, name, method, callself, ctx): + def __init__(self, name, method, callself, ctx) -> None: super().__init__(name, ctx) self.cls = self.ctx.convert.function_type self.method = method @@ -455,7 +459,7 @@ class SignedFunction(Function): Subclasses should define call(self, node, f, args) and set self.bound_class. 
""" - def __init__(self, signature, ctx): + def __init__(self, signature, ctx) -> None: # We should only instantiate subclasses of SignedFunction assert self.__class__ != SignedFunction super().__init__(signature.name, ctx) @@ -495,7 +499,7 @@ def get_self_type_param(self): return param return None - def argcount(self, _): + def argcount(self, _) -> int: return len(self.signature.param_names) def get_nondefault_params(self): @@ -505,7 +509,7 @@ def get_nondefault_params(self): if n not in self.signature.defaults ) - def match_and_map_args(self, node, args, alias_map): + def match_and_map_args(self, node, args, alias_map) -> tuple[Any, Any]: """Calls match_args() and _map_args().""" return self.match_args(node, args, alias_map), self._map_args(node, args) @@ -635,10 +639,10 @@ def _match_args_sequentially(self, node, args, alias_map, match_all_views): ) return [m.subst for m in matches] - def get_first_opcode(self): + def get_first_opcode(self) -> None: return None - def set_function_defaults(self, node, defaults_var): + def set_function_defaults(self, node, defaults_var) -> None: """Attempts to set default arguments of a function. If defaults_var is not an unambiguous tuple (i.e. one that can be processed @@ -659,7 +663,7 @@ def set_function_defaults(self, node, defaults_var): defaults = dict(zip(self.signature.param_names[-len(defaults) :], defaults)) self.signature.defaults = defaults - def _mutations_generator(self, node, first_arg, substs): + def _mutations_generator(self, node, first_arg, substs) -> Callable[[], Any]: def generator(): """Yields mutations.""" if ( @@ -697,7 +701,7 @@ def generator(): # extra time. return generator - def update_signature_scope(self, cls): + def update_signature_scope(self, cls) -> None: self.signature.excluded_types.update([t.name for t in cls.template]) self.signature.add_scope(cls) @@ -709,7 +713,7 @@ class SimpleFunction(SignedFunction): record calls or try to infer types. 
""" - def __init__(self, signature, ctx): + def __init__(self, signature, ctx) -> None: super().__init__(signature, ctx) self.bound_class = BoundFunction @@ -783,7 +787,7 @@ def _skip_parameter_matching(self): return self.signature.has_return_annotation or self.full_name == "__init__" - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple[Any, Any]: args = args.simplify(node, self.ctx) callargs = self._map_args(node, args) substs = [] diff --git a/pytype/abstract/_instance_base.py b/pytype/abstract/_instance_base.py index 5a829ae69..41e68d198 100644 --- a/pytype/abstract/_instance_base.py +++ b/pytype/abstract/_instance_base.py @@ -1,6 +1,7 @@ """Abstract representation of instances.""" import logging +from typing import Any, TypeVar from pytype import datatypes from pytype.abstract import _base @@ -9,7 +10,9 @@ from pytype.abstract import function from pytype.errors import error_types -log = logging.getLogger(__name__) +_T0 = TypeVar("_T0") + +log: logging.Logger = logging.getLogger(__name__) _isinstance = abstract_utils._isinstance # pylint: disable=protected-access @@ -47,7 +50,7 @@ def __init__(self, name, ctx): self._fullhash = None self._cached_changestamps = self._get_changestamps() - def _get_changestamps(self): + def _get_changestamps(self) -> tuple[Any, Any]: return ( self.members.changestamp, self._instance_type_parameters.changestamp, @@ -73,7 +76,7 @@ def maybe_missing_members(self): def maybe_missing_members(self, v): self._maybe_missing_members = v - def has_instance_type_parameter(self, name): + def has_instance_type_parameter(self, name) -> bool: """Check if the key is in `instance_type_parameters`.""" name = abstract_utils.full_type_name(self, name) return name in self.instance_type_parameters @@ -89,7 +92,7 @@ def get_instance_type_parameter(self, name, node=None): self.instance_type_parameters[name] = param return param - def merge_instance_type_parameter(self, node, name, value): + def 
merge_instance_type_parameter(self, node, name, value) -> None: """Set the value of a type parameter. This will always add to the type parameter unlike set_attribute which will @@ -109,7 +112,7 @@ def merge_instance_type_parameter(self, node, name, value): else: self.instance_type_parameters[name] = value - def _call_helper(self, node, obj, binding, args): + def _call_helper(self, node, obj, binding, args) -> tuple[Any, Any]: obj_binding = binding if obj == binding.data else obj.to_binding(node) node, var = self.ctx.attribute_handler.get_attribute( node, obj, "__call__", obj_binding @@ -133,7 +136,7 @@ def argcount(self, node): # value will lead to a not-callable error anyways. return 0 - def __repr__(self): + def __repr__(self) -> str: return f"<{self.name} [{self.cls!r}]>" def _get_class(self): @@ -152,7 +155,7 @@ def cls(self): def cls(self, cls): self._cls = cls - def set_class(self, node, var): + def set_class(self, node: _T0, var) -> _T0: """Set the __class__ of an instance, for code that does "x.__class__ = y.""" # Simplification: Setting __class__ is done rarely, and supporting this # action would complicate pytype considerably by forcing us to track the @@ -166,7 +169,7 @@ def set_class(self, node, var): self.cls = self.ctx.convert.unsolvable return node - def update_caches(self, force=False): + def update_caches(self, force=False) -> None: cur_changestamps = self._get_changestamps() if self._cached_changestamps == cur_changestamps and not force: return @@ -205,7 +208,7 @@ def get_type_key(self, seen=None): self._type_key = frozenset(key) return self._type_key - def _unique_parameters(self): + def _unique_parameters(self) -> list: parameters = super()._unique_parameters() parameters.extend(self.instance_type_parameters.values()) return parameters @@ -217,14 +220,14 @@ def instantiate(self, node, container=None): class Instance(SimpleValue): """An instance of some object.""" - def __init__(self, cls, ctx, container=None): + def __init__(self, cls, ctx, 
container=None) -> None: super().__init__(cls.name, ctx) self.cls = cls self._instance_type_parameters_loaded = False self._container = container cls.register_instance(self) - def _load_instance_type_parameters(self): + def _load_instance_type_parameters(self) -> None: if self._instance_type_parameters_loaded: return all_formal_type_parameters = datatypes.AliasingDict() diff --git a/pytype/abstract/_instances.py b/pytype/abstract/_instances.py index df1574654..52dc765db 100644 --- a/pytype/abstract/_instances.py +++ b/pytype/abstract/_instances.py @@ -1,8 +1,9 @@ """Specialized instance representations.""" +from collections.abc import Generator as _Generator import contextlib import logging -from typing import Union +from typing import Any, TypeVar, Union from pytype.abstract import _base from pytype.abstract import _instance_base @@ -15,7 +16,9 @@ from pytype.typegraph import cfg_utils from pytype.types import types -log = logging.getLogger(__name__) +_T0 = TypeVar("_T0") + +log: logging.Logger = logging.getLogger(__name__) _make = abstract_utils._make # pylint: disable=protected-access @@ -46,7 +49,7 @@ class LazyConcreteDict( ): """Dictionary with lazy values.""" - def __init__(self, name, member_map, ctx): + def __init__(self, name, member_map, ctx) -> None: super().__init__(name, ctx) mixin.PythonConstant.init_mixin(self, self.members) mixin.LazyMembers.init_mixin(self, member_map) @@ -54,25 +57,25 @@ def __init__(self, name, member_map, ctx): def _convert_member(self, name, member, subst=None): return self.ctx.convert.constant_to_var(member) - def is_empty(self): + def is_empty(self) -> bool: return not bool(self._member_map) class ConcreteValue(_instance_base.Instance, mixin.PythonConstant): """Abstract value with a concrete fallback.""" - def __init__(self, pyval, cls, ctx): + def __init__(self, pyval, cls, ctx) -> None: super().__init__(cls, ctx) mixin.PythonConstant.init_mixin(self, pyval) - def get_fullhash(self, seen=None): + def get_fullhash(self, 
seen=None) -> int: return hash((type(self), id(self.pyval))) class Module(_instance_base.Instance, mixin.LazyMembers, types.Module): """Represents an (imported) module.""" - def __init__(self, ctx, name, member_map, ast): + def __init__(self, ctx, name, member_map, ast) -> None: super().__init__(ctx.convert.module_type, ctx) self.name = name self.ast = ast @@ -109,7 +112,7 @@ def module(self, m): def full_name(self): return self.ast.name - def has_getattr(self): + def has_getattr(self) -> bool: """Does this module have a module-level __getattr__? We allow __getattr__ on the module level to specify that this module doesn't @@ -148,12 +151,12 @@ def get_submodule(self, node, name): log.warning("Couldn't find attribute / module %r", full_name) return None - def items(self): + def items(self) -> list[tuple[Any, Any]]: for name in self._member_map: self.load_lazy_attribute(name) return list(self.members.items()) - def get_fullhash(self, seen=None): + def get_fullhash(self, seen=None) -> int: """Hash the set of member names.""" return hash((type(self), self.full_name) + tuple(sorted(self._member_map))) @@ -161,7 +164,7 @@ def get_fullhash(self, seen=None): class Coroutine(_instance_base.Instance): """A representation of instances of coroutine.""" - def __init__(self, ctx, ret_var, node): + def __init__(self, ctx, ret_var, node) -> None: super().__init__(ctx.convert.coroutine_type, ctx) self.merge_instance_type_parameter( node, abstract_utils.T, self.ctx.new_unsolvable(node) @@ -177,26 +180,26 @@ def __init__(self, ctx, ret_var, node): class Iterator(_instance_base.Instance, mixin.HasSlots): """A representation of instances of iterators.""" - def __init__(self, ctx, return_var): + def __init__(self, ctx, return_var) -> None: super().__init__(ctx.convert.iterator_type, ctx) mixin.HasSlots.init_mixin(self) self.set_native_slot("__next__", self.next_slot) self._return_var = return_var - def next_slot(self, node): + def next_slot(self, node: _T0) -> tuple[_T0, Any]: return 
node, self._return_var class BaseGenerator(_instance_base.Instance): """A base class of instances of generators and async generators.""" - def __init__(self, generator_type, frame, ctx, is_return_allowed): + def __init__(self, generator_type, frame, ctx, is_return_allowed) -> None: super().__init__(generator_type, ctx) self.frame = frame self.runs = 0 self.is_return_allowed = is_return_allowed # if return statement is allowed - def run_generator(self, node): + def run_generator(self, node) -> tuple[Any, Any]: """Run the generator.""" if self.runs == 0: # Optimization: We only run it once. node, _ = self.ctx.vm.resume_frame(node, self.frame) @@ -238,7 +241,7 @@ def call(self, node, func, args, alias_map=None): class AsyncGenerator(BaseGenerator): """A representation of instances of async generators.""" - def __init__(self, async_generator_frame, ctx): + def __init__(self, async_generator_frame, ctx) -> None: super().__init__( ctx.convert.async_generator_type, async_generator_frame, ctx, False ) @@ -247,7 +250,7 @@ def __init__(self, async_generator_frame, ctx): class Generator(BaseGenerator): """A representation of instances of generators.""" - def __init__(self, generator_frame, ctx): + def __init__(self, generator_frame, ctx) -> None: super().__init__(ctx.convert.generator_type, generator_frame, ctx, True) def get_special_attribute(self, node, name, valself): @@ -264,14 +267,14 @@ def get_special_attribute(self, node, name, valself): else: return super().get_special_attribute(node, name, valself) - def __iter__(self, node): # pylint: disable=non-iterator-returned,unexpected-special-method-signature + def __iter__(self, node: _T0) -> tuple[_T0, Any]: # pylint: disable=non-iterator-returned,unexpected-special-method-signature return node, self.to_variable(node) class Tuple(_instance_base.Instance, mixin.PythonConstant): """Representation of Python 'tuple' objects.""" - def __init__(self, content, ctx): + def __init__(self, content, ctx) -> None: combined_content = 
ctx.convert.build_content(content) class_params = { name: ctx.convert.merge_classes(instance_param.data) @@ -287,7 +290,7 @@ def __init__(self, content, ctx): # set this to true when creating a function arg tuple self.is_unpacked_function_args = False - def str_of_constant(self, printer): + def str_of_constant(self, printer) -> str: content = ", ".join( " or ".join(_var_map(printer, val)) for val in self.pyval ) @@ -295,12 +298,12 @@ def str_of_constant(self, printer): content += "," return f"({content})" - def _unique_parameters(self): + def _unique_parameters(self) -> list: parameters = super()._unique_parameters() parameters.extend(self.pyval) return parameters - def _is_recursive(self): + def _is_recursive(self) -> bool: """True if the tuple contains itself.""" return any(any(x is self for x in e.data) for e in self.pyval) @@ -335,7 +338,7 @@ def get_fullhash(self, seen=None): class List(_instance_base.Instance, mixin.HasSlots, mixin.PythonConstant): # pytype: disable=signature-mismatch """Representation of Python 'list' objects.""" - def __init__(self, content, ctx): + def __init__(self, content, ctx) -> None: super().__init__(ctx.convert.list_type, ctx) self._instance_cache = {} combined_content = ctx.convert.build_content(content) @@ -345,12 +348,12 @@ def __init__(self, content, ctx): self.set_native_slot("__getitem__", self.getitem_slot) self.set_native_slot("__getslice__", self.getslice_slot) - def str_of_constant(self, printer): + def str_of_constant(self, printer) -> str: return "[%s]" % ", ".join( " or ".join(_var_map(printer, val)) for val in self.pyval ) - def __repr__(self): + def __repr__(self) -> str: if self.is_concrete: return mixin.PythonConstant.__repr__(self) else: @@ -361,11 +364,11 @@ def get_fullhash(self, seen=None): return _get_concrete_sequence_fullhash(self, seen) return super().get_fullhash(seen) - def merge_instance_type_parameter(self, node, name, value): + def merge_instance_type_parameter(self, node, name, value) -> None: 
self.is_concrete = False super().merge_instance_type_parameter(node, name, value) - def getitem_slot(self, node, index_var): + def getitem_slot(self, node, index_var) -> tuple[Any, Any]: """Implements __getitem__ for List. Arguments: @@ -419,7 +422,7 @@ def _get_index(self, data): else: raise abstract_utils.ConversionError() - def getslice_slot(self, node, start_var, end_var): + def getslice_slot(self, node, start_var, end_var) -> tuple[Any, Any]: """Implements __getslice__ for List. Arguments: @@ -460,7 +463,7 @@ class Dict(_instance_base.Instance, mixin.HasSlots, mixin.PythonDict): of what got stored. """ - def __init__(self, ctx): + def __init__(self, ctx) -> None: super().__init__(ctx.convert.dict_type, ctx) mixin.HasSlots.init_mixin(self) self.set_native_slot("__contains__", self.contains_slot) @@ -473,7 +476,7 @@ def __init__(self, ctx): # For example: f_locals["__annotations__"] mixin.PythonDict.init_mixin(self, {}) - def str_of_constant(self, printer): + def str_of_constant(self, printer) -> str: # self.pyval is only populated for string keys. 
if not self.is_concrete: return "{...: ...}" @@ -483,7 +486,7 @@ def str_of_constant(self, printer): ] return "{" + ", ".join(pairs) + "}" - def __repr__(self): + def __repr__(self) -> str: if not hasattr(self, "is_concrete"): return "Dict (not fully initialized)" elif self.is_concrete: @@ -504,7 +507,7 @@ def get_fullhash(self, seen=None): + abstract_utils.get_dict_fullhash_component(self.pyval, seen=seen) ) - def getitem_slot(self, node, name_var): + def getitem_slot(self, node, name_var) -> tuple[Any, Any]: """Implements the __getitem__ slot.""" results = [] unresolved = False @@ -528,11 +531,11 @@ def getitem_slot(self, node, name_var): results.append(ret) return node, self.ctx.join_variables(node, results) - def merge_instance_type_params(self, node, name_var, value_var): + def merge_instance_type_params(self, node, name_var, value_var) -> None: self.merge_instance_type_parameter(node, abstract_utils.K, name_var) self.merge_instance_type_parameter(node, abstract_utils.V, value_var) - def set_str_item(self, node, name, value_var): + def set_str_item(self, node: _T0, name, value_var) -> _T0: name_var = self.ctx.convert.build_nonatomic_string(node) self.merge_instance_type_params(node, name_var, value_var) if name in self.pyval: @@ -558,7 +561,7 @@ def setitem( else: self.pyval[name] = value_var - def setitem_slot(self, node, name_var, value_var): + def setitem_slot(self, node, name_var, value_var) -> tuple[Any, Any]: """Implements the __setitem__ slot.""" self.setitem(node, name_var, value_var) return self.call_pytd( @@ -568,7 +571,7 @@ def setitem_slot(self, node, name_var, value_var): abstract_utils.abstractify_variable(value_var, self.ctx), ) - def setdefault_slot(self, node, name_var, value_var=None): + def setdefault_slot(self, node, name_var, value_var=None) -> tuple[Any, Any]: if value_var is None: value_var = self.ctx.convert.build_none(node) # We don't have a good way of modelling the exact setdefault behavior - @@ -578,7 +581,7 @@ def 
setdefault_slot(self, node, name_var, value_var=None): self.setitem(node, name_var, value_var) return self.call_pytd(node, "setdefault", name_var, value_var) - def contains_slot(self, node, key_var): + def contains_slot(self, node: _T0, key_var) -> tuple[_T0, Any]: if self.is_concrete: try: str_key = abstract_utils.get_atomic_python_constant(key_var, str) @@ -590,7 +593,7 @@ def contains_slot(self, node, key_var): value = None return node, self.ctx.convert.build_bool(node, value) - def pop_slot(self, node, key_var, default_var=None): + def pop_slot(self, node, key_var, default_var=None) -> tuple[Any, Any]: try: str_key = abstract_utils.get_atomic_python_constant(key_var, str) except abstract_utils.ConversionError: @@ -608,21 +611,23 @@ def pop_slot(self, node, key_var, default_var=None): except KeyError as e: raise error_types.DictKeyMissing(str_key) from e - def _set_params_to_any(self, node): + def _set_params_to_any(self, node) -> None: self.is_concrete = False unsolvable = self.ctx.new_unsolvable(node) for p in (abstract_utils.K, abstract_utils.V): self.merge_instance_type_parameter(node, p, unsolvable) @contextlib.contextmanager - def _set_params_to_any_on_failure(self, node): + def _set_params_to_any_on_failure( + self, node + ): try: yield except error_types.FailedFunctionCall: self._set_params_to_any(node) raise - def update_slot(self, node, *args, **kwargs): + def update_slot(self, node, *args, **kwargs) -> tuple[Any, Any]: if len(args) == 1 and len(args[0].data) == 1: with self._set_params_to_any_on_failure(node): for f in self._super["update"].data: @@ -663,7 +668,7 @@ def update( class AnnotationsDict(Dict): """__annotations__ dict.""" - def __init__(self, annotated_locals, ctx): + def __init__(self, annotated_locals, ctx) -> None: self.annotated_locals = annotated_locals super().__init__(ctx) @@ -672,13 +677,13 @@ def get_type(self, node, name): return None return self.annotated_locals[name].get_type(node, name) - def get_annotations(self, node): + def 
get_annotations(self, node) -> _Generator[tuple[Any, Any], Any, None]: for name, local in self.annotated_locals.items(): typ = local.get_type(node, name) if typ: yield name, typ - def __repr__(self): + def __repr__(self) -> str: return repr(self.annotated_locals) @@ -688,7 +693,7 @@ def __repr__(self): class Splat(_base.BaseValue): """Representation of unpacked iterables.""" - def __init__(self, ctx, iterable): + def __init__(self, ctx, iterable) -> None: super().__init__("splat", ctx) # When building a tuple for a function call, we preserve splats as elements # in a concrete tuple (e.g. f(x, *ys, z) gets called with the concrete tuple @@ -700,14 +705,14 @@ def __init__(self, ctx, iterable): self.cls = ctx.convert.unsolvable self.iterable = iterable - def __repr__(self): + def __repr__(self) -> str: return f"splat({self.iterable.data!r})" class SequenceLength(_base.BaseValue, mixin.HasSlots): """Sequence length for match statements.""" - def __init__(self, sequence, ctx): + def __init__(self, sequence, ctx) -> None: super().__init__("SequenceLength", ctx) length = 0 splat = False @@ -721,14 +726,14 @@ def __init__(self, sequence, ctx): mixin.HasSlots.init_mixin(self) self.set_native_slot("__sub__", self.sub_slot) - def __repr__(self): + def __repr__(self) -> str: splat = "+" if self.splat else "" return f"SequenceLength[{self.length}{splat}]" def instantiate(self, node, container=None): return self.to_variable(node) - def sub_slot(self, node, other_var): + def sub_slot(self, node: _T0, other_var) -> tuple[_T0, Any]: # We should not get a ConversionError here; this is code generated by the # compiler from a literal sequence in a concrete match statement val = abstract_utils.get_atomic_python_constant(other_var, int) diff --git a/pytype/abstract/_interpreter_function.py b/pytype/abstract/_interpreter_function.py index 008757218..fd830ec23 100644 --- a/pytype/abstract/_interpreter_function.py +++ b/pytype/abstract/_interpreter_function.py @@ -1,10 +1,12 @@ """Abstract 
representation of functions defined in the module under analysis.""" import collections +from collections.abc import Generator import contextlib import hashlib import itertools import logging +from typing import Any, TypeVar from pytype.abstract import _classes from pytype.abstract import _function_base @@ -17,13 +19,17 @@ from pytype.errors import error_types from pytype.pytd import pytd from pytype.pytd import pytd_utils +from pytype.pytd.pytd import ParameterKind from pytype.typegraph import cfg_utils -log = logging.getLogger(__name__) +_T0 = TypeVar("_T0") +_T2 = TypeVar("_T2") + +log: logging.Logger = logging.getLogger(__name__) _isinstance = abstract_utils._isinstance # pylint: disable=protected-access -def _matches_generator_helper(type_obj, allowed_types): +def _matches_generator_helper(type_obj, allowed_types) -> bool: """Check if type_obj matches a Generator/AsyncGenerator type.""" if isinstance(type_obj, _typing.Union): return all( @@ -50,7 +56,7 @@ def _matches_async_generator(type_obj): return _matches_generator_helper(type_obj, allowed_types) -def _hash_all_dicts(*hash_args): +def _hash_all_dicts(*hash_args) -> bytes: """Convenience method for hashing a sequence of dicts.""" components = ( abstract_utils.get_dict_fullhash_component(d, names=n) @@ -61,7 +67,7 @@ def _hash_all_dicts(*hash_args): ).digest() -def _check_classes(var, check): +def _check_classes(var, check) -> bool: """Check whether the cls of each value in `var` is a class and passes `check`. Args: @@ -198,7 +204,7 @@ def __init__( annotations, overloads, ctx, - ): + ) -> None: log.debug("Creating InterpreterFunction %r for %r", name, code.name) self.bound_class = _function_base.BoundInterpreterFunction self.doc = code.consts[0] if code.consts else None @@ -244,14 +250,14 @@ def __init__( self.cache_return = False @contextlib.contextmanager - def record_calls(self): + def record_calls(self) -> contextlib._GeneratorContextManager: """Turn on recording of function calls. 
Used by analyze.py.""" old = self._store_call_records self._store_call_records = True yield self._store_call_records = old - def _check_signature(self): + def _check_signature(self) -> None: """Validate function signature.""" for ann in self.signature.annotations.values(): if isinstance(ann, _typing.FinalAnnotation): @@ -308,7 +314,7 @@ def _check_signature(self): f"{input_pytd}", ) - def _build_signature(self, name, annotations): + def _build_signature(self, name, annotations) -> function.Signature: """Build a function.Signature object representing this function.""" vararg_name = None kwarg_name = None @@ -335,7 +341,7 @@ def _build_signature(self, name, annotations): annotations, ) - def _update_signature_scope_from_closure(self): + def _update_signature_scope_from_closure(self) -> None: # If this is a nested function in an instance method and the nested function # accesses 'self', then the first variable in the closure is 'self'. We use # 'self' to update the scopes of any type parameters in the nested method's @@ -363,7 +369,7 @@ def match_args(self, node, args, alias_map=None, match_all_views=False): return return super().match_args(node, args, alias_map, match_all_views) - def _inner_cls_check(self, last_frame): + def _inner_cls_check(self, last_frame) -> None: """Check if the function and its nested class use same type parameter.""" # get all type parameters from function annotations all_type_parameters = [] @@ -403,7 +409,7 @@ def signature_functions(self): """Get the functions that describe this function's signature.""" return self._active_overloads or [self] - def iter_signature_functions(self): + def iter_signature_functions(self) -> Generator[Any, Any, None]: """Loop through signatures, setting each as the primary one in turn.""" if not self._all_overloads: yield self @@ -417,7 +423,7 @@ def iter_signature_functions(self): self._active_overloads = old_overloads @contextlib.contextmanager - def reset_overloads(self): + def reset_overloads(self) -> 
contextlib._GeneratorContextManager: if self._all_overloads == self._active_overloads: yield return @@ -428,7 +434,7 @@ def reset_overloads(self): finally: self._active_overloads = old_overloads - def _find_matching_sig(self, node, args, alias_map): + def _find_matching_sig(self, node, args, alias_map) -> tuple[Any, Any, Any]: error = None for f in self.signature_functions(): try: @@ -443,12 +449,12 @@ def _find_matching_sig(self, node, args, alias_map): return f.signature, substs, callargs raise error # pylint: disable=raising-bad-type - def _set_callself_maybe_missing_members(self): + def _set_callself_maybe_missing_members(self) -> None: if self.ctx.callself_stack: for b in self.ctx.callself_stack[-1].bindings: b.data.maybe_missing_members = True - def _is_unannotated_contextmanager_exit(self, func, args): + def _is_unannotated_contextmanager_exit(self, func, args) -> bool: """Returns whether this is an unannotated contextmanager __exit__ method. If this is a bound method named __exit__ that has no type annotations and is @@ -472,7 +478,9 @@ def _is_unannotated_contextmanager_exit(self, func, args): and not args.starstarargs ) - def _fix_args_for_unannotated_contextmanager_exit(self, node, func, args): + def _fix_args_for_unannotated_contextmanager_exit( + self, node, func, args: _T2 + ) -> function.Args | _T2: """Adjust argument types for a contextmanager's __exit__ method.""" if not self._is_unannotated_contextmanager_exit(func.data, args): return args @@ -541,7 +549,7 @@ def _paramspec_signature(self, callable_type, substs): pspec_match, r_args, return_value, self.ctx ) - def _handle_paramspec(self, sig, annotations, substs, callargs): + def _handle_paramspec(self, sig, annotations, substs, callargs) -> None: if not sig.has_return_annotation: return retval = sig.annotations["return"] @@ -559,8 +567,14 @@ def _handle_paramspec(self, sig, annotations, substs, callargs): annotations[name] = param_annot def call( - self, node, func, args, alias_map=None, 
new_locals=False, frame_substs=() - ): + self, + node, + func, + args, + alias_map=None, + new_locals=False, + frame_substs=(), + ) -> tuple[Any, Any]: if self.is_overload: raise error_types.NotCallable(self) args = self._fix_args_for_unannotated_contextmanager_exit(node, func, args) @@ -755,7 +769,7 @@ def call( self.last_frame = frame return node_after_call, typeguard_return or ret - def get_call_combinations(self, node): + def get_call_combinations(self, node: _T0) -> list[tuple[_T0, Any, Any]]: """Get this function's call records.""" all_combinations = [] signature_data = set() @@ -801,17 +815,19 @@ def get_call_combinations(self, node): all_combinations.append((node, params, ret)) return all_combinations - def get_positional_names(self): + def get_positional_names(self) -> list: return list(self.code.varnames[: self.code.argcount]) - def get_nondefault_params(self): + def get_nondefault_params(self) -> Generator[tuple[Any, bool], Any, None]: for i in range(self.nonstararg_count): yield self.code.varnames[i], i >= self.code.argcount - def get_kwonly_names(self): + def get_kwonly_names(self) -> list: return list(self.code.varnames[self.code.argcount : self.nonstararg_count]) - def get_parameters(self): + def get_parameters( + self, + ) -> Generator[tuple[Any, ParameterKind, bool], Any, None]: default_pos = self.code.argcount - len(self.defaults) i = 0 for name in self.get_positional_names(): @@ -831,7 +847,9 @@ def has_varargs(self): def has_kwargs(self): return self.code.has_varkeywords() - def property_get(self, callself, is_class=False): + def property_get( + self, callself, is_class=False + ) -> _function_base.BoundFunction | _function_base.Function: if self.name.endswith(".__init__") and self.signature.param_names: self_name = self.signature.param_names[0] # If `_has_self_annot` is True, then we've intentionally temporarily @@ -854,7 +872,7 @@ def is_coroutine(self): def is_unannotated_coroutine(self): return self.is_coroutine() and not 
self.signature.has_return_annotation - def has_empty_body(self): + def has_empty_body(self) -> bool: # TODO(mdemello): Optimise this. ops = list(self.code.code_iter) if self.ctx.python_version >= (3, 12): @@ -885,7 +903,7 @@ def get_self_type_param(self): return None @contextlib.contextmanager - def set_self_annot(self, annot_class): + def set_self_annot(self, annot_class) -> contextlib._GeneratorContextManager: if self.is_overload or not self._active_overloads: with super().set_self_annot(annot_class): yield diff --git a/pytype/abstract/_pytd_function.py b/pytype/abstract/_pytd_function.py index d5ff0955a..e2acc8de0 100644 --- a/pytype/abstract/_pytd_function.py +++ b/pytype/abstract/_pytd_function.py @@ -4,7 +4,7 @@ import contextlib import itertools import logging -from typing import Any +from typing import Any, TypeVar from pytype import datatypes from pytype import utils @@ -25,11 +25,13 @@ from pytype.typegraph import cfg from pytype.types import types -log = logging.getLogger(__name__) +_T0 = TypeVar("_T0") + +log: logging.Logger = logging.getLogger(__name__) _isinstance = abstract_utils._isinstance # pylint: disable=protected-access # pytype.matcher.GoodMatch, which can't be imported due to a circular dep -_GoodMatchType = Any +_GoodMatchType: Any = Any class SignatureMutationError(Exception): @@ -48,7 +50,7 @@ def _is_literal(annot: _base.BaseValue | None): class _MatchedSignatures: """Function call matches.""" - def __init__(self, args, can_match_multiple): + def __init__(self, args, can_match_multiple) -> None: self._args_vars = set(args.get_variables()) self._can_match_multiple = can_match_multiple self._data: list[ @@ -56,7 +58,7 @@ def __init__(self, args, can_match_multiple): ] = [] self._sig = self._cur_data = None - def __bool__(self): + def __bool__(self) -> bool: return bool(self._data) @contextlib.contextmanager @@ -73,7 +75,7 @@ def with_signature(self, sig): self._data.extend(self._cur_data) self._sig = self._cur_data = None - def add(self, 
arg_dict, match): + def add(self, arg_dict, match) -> None: """Adds a new match.""" for sigs in self._data: if sigs[-1][0] == self._sig: @@ -88,7 +90,7 @@ def add(self, arg_dict, match): assert self._cur_data is not None self._cur_data.append([(self._sig, arg_dict, match)]) - def get(self): + def get(self) -> list[list[tuple[Any, Any, Any]]]: """Gets the matches.""" return self._data @@ -122,7 +124,7 @@ def make(cls, name, ctx, module, pyval_name=None): self.module = module return self - def __init__(self, name, signatures, kind, decorators, ctx): + def __init__(self, name, signatures, kind, decorators, ctx) -> None: super().__init__(name, ctx) assert signatures self.kind = kind @@ -142,7 +144,13 @@ def __init__(self, name, signatures, kind, decorators, ctx): sig.name = self.name self.decorators = [d.type.name for d in decorators] - def property_get(self, callself, is_class=False): + def property_get(self, callself, is_class=False) -> ( + _function_base.BoundFunction | + _function_base.ClassMethod | + _function_base.Function | + _function_base.Property | + _function_base.StaticMethod + ): if self.kind == pytd.MethodKind.STATICMETHOD: if is_class: # Binding the function to None rather than not binding it tells @@ -172,7 +180,7 @@ def property_get(self, callself, is_class=False): def argcount(self, _): return min(sig.signature.mandatory_param_count() for sig in self.signatures) - def _log_args(self, arg_values_list, level=0, logged=None): + def _log_args(self, arg_values_list, level=0, logged=None) -> None: """Log the argument values.""" if log.isEnabledFor(logging.DEBUG): if logged is None: @@ -198,7 +206,7 @@ def _log_args(self, arg_values_list, level=0, logged=None): logged | {value.data}, ) - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple[Any, Any]: # TODO(b/159052609): We should be passing function signatures to simplify. 
if len(self.signatures) == 1: args = args.simplify(node, self.ctx, self.signatures[0].signature) @@ -342,7 +350,7 @@ def compatible_with(new, existing, view): node = abstract_utils.apply_mutations(node, all_mutations.__iter__) return node, retvar - def _get_mutation_to_unknown(self, node, values): + def _get_mutation_to_unknown(self, node, values) -> list[function.Mutation]: """Mutation for making all type parameters in a list of instances "unknown". This is used if we call a function that has mutable parameters and @@ -384,7 +392,9 @@ def _can_match_multiple(self, args): # An opaque *args or **kwargs behaves like an unknown. return args.has_opaque_starargs_or_starstarargs() - def _call_with_signatures(self, node, func, args, view, signatures): + def _call_with_signatures( + self, node: _T0, func, args, view, signatures + ) -> tuple[_T0, Any, Any]: """Perform a function call that involves multiple signatures.""" ret_type = self._combine_multiple_returns(signatures) if self.ctx.options.protocols and isinstance(ret_type, pytd.AnythingType): @@ -502,7 +512,7 @@ def _match_args_sequentially(self, node, args, alias_map, match_all_views): raise error return matched_signatures.get() - def set_function_defaults(self, node, defaults_var): + def set_function_defaults(self, node, defaults_var) -> None: """Attempts to set default arguments for a function's signatures. If defaults_var is not an unambiguous tuple (i.e. one that can be processed @@ -543,7 +553,7 @@ class PyTDSignature(utils.ContextWeakrefMixin): type. 
""" - def __init__(self, name, pytd_sig, ctx): + def __init__(self, name, pytd_sig, ctx) -> None: super().__init__(ctx) self.name = name self.pytd_sig = pytd_sig @@ -565,7 +575,7 @@ def __init__(self, name, pytd_sig, ctx): log.error("New: %s", pytd_utils.Print(p.mutated_type)) raise SignatureMutationError(pytd_sig) from e - def _map_args(self, node, args): + def _map_args(self, node, args) -> tuple[Any, dict]: """Map the passed arguments to a name->binding dictionary. Args: @@ -653,7 +663,7 @@ def _map_args(self, node, args): return formal_args, arg_dict - def _fill_in_missing_parameters(self, node, args, arg_dict): + def _fill_in_missing_parameters(self, node, args, arg_dict) -> None: for p in self.pytd_sig.params: if p.name not in arg_dict: if ( @@ -667,7 +677,9 @@ def _fill_in_missing_parameters(self, node, args, arg_dict): # Assume the missing parameter is filled in by *args or **kwargs. arg_dict[p.name] = self.ctx.new_unsolvable(node) - def substitute_formal_args(self, node, args, match_all_views, keep_all_views): + def substitute_formal_args( + self, node, args, match_all_views, keep_all_views + ) -> tuple[Any, Any]: """Substitute matching args into this signature. 
Used by PyTDFunction.""" formal_args, arg_dict = self._map_args(node, args) self._fill_in_missing_parameters(node, args, arg_dict) @@ -713,7 +725,7 @@ def _paramspec_signature(self, callable_type, return_value, subst): ret.AddBinding(_function_base.SimpleFunction(ret_sig, self.ctx)) return ret - def _handle_paramspec(self, node, key, ret_map): + def _handle_paramspec(self, node, key, ret_map) -> None: """Construct a new function based on ParamSpec matching.""" return_callable, subst = key val = self.ctx.convert.constant_to_value( @@ -737,7 +749,9 @@ def _handle_paramspec(self, node, key, ret_map): if ret: ret_map[key] = ret - def call_with_args(self, node, func, arg_dict, match, ret_map): + def call_with_args( + self, node, func, arg_dict, match, ret_map + ) -> tuple[Any, Any, Any]: """Call this signature. Used by PyTDFunction.""" subst = match.subst ret = self.pytd_sig.return_type @@ -783,7 +797,7 @@ def call_with_args(self, node, func, arg_dict, match, ret_map): return node, ret_map[t], mutations @classmethod - def _collect_mutated_parameters(cls, typ, mutated_type): + def _collect_mutated_parameters(cls, typ, mutated_type) -> list: if not mutated_type: return [] if isinstance(typ, pytd.UnionType) and isinstance( @@ -815,7 +829,9 @@ def _collect_mutated_parameters(cls, typ, mutated_type): list(zip(mutated_type.base_type.cls.template, mutated_type.parameters)) ] - def _get_mutation(self, node, arg_dict, subst, retvar): + def _get_mutation( + self, node, arg_dict, subst, retvar + ) -> list[function.Mutation]: """Mutation for changing the type parameters of mutable arguments. 
This will adjust the type parameters as needed for pytd functions like: diff --git a/pytype/abstract/_singletons.py b/pytype/abstract/_singletons.py index 3805e09cf..711bb3c09 100644 --- a/pytype/abstract/_singletons.py +++ b/pytype/abstract/_singletons.py @@ -1,15 +1,21 @@ """Singleton abstract values.""" import logging +from typing import Any, Optional, TypeVar from pytype import datatypes from pytype.abstract import _base from pytype.pytd import escape from pytype.pytd import pytd from pytype.pytd import pytd_utils +from pytype.pytd.pytd import Class from pytype.typegraph import cfg +from pytype.types import types -log = logging.getLogger(__name__) + +_T0 = TypeVar("_T0") + +log: logging.Logger = logging.getLogger(__name__) class Unknown(_base.BaseValue): @@ -28,9 +34,9 @@ class Unknown(_base.BaseValue): _current_id = 0 # For simplicity, Unknown doesn't emulate descriptors: - IGNORED_ATTRIBUTES = ["__get__", "__set__", "__getattribute__"] + IGNORED_ATTRIBUTES: list[str] = ["__get__", "__set__", "__getattribute__"] - def __init__(self, ctx): + def __init__(self, ctx) -> None: name = escape.unknown(Unknown._current_id) super().__init__(name, ctx) self.members = datatypes.MonitorDict() @@ -43,7 +49,7 @@ def __init__(self, ctx): def compute_mro(self): return self.default_mro() - def get_fullhash(self, seen=None): + def get_fullhash(self, seen=None) -> int: # Unknown needs its own implementation of get_fullhash to ensure equivalent # Unknowns produce the same hash. "Equivalent" in this case means "has the # same members," so member names are used in the hash instead of id(). 
@@ -61,7 +67,7 @@ def _to_pytd(cls, node, v): return v.to_pytd_type(node) @classmethod - def _make_params(cls, node, args, kwargs): + def _make_params(cls, node, args, kwargs) -> tuple: """Convert a list of types/variables to pytd parameters.""" def _make_param(name, p): @@ -97,14 +103,14 @@ def get_special_attribute(self, node, name, valself): ) return new - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, alias_map=None) -> tuple[_T0, Any]: ret = self.ctx.convert.create_new_unknown( node, source=self.owner, action="call:" + self.name ) self._calls.append((args.posargs, args.namedargs, ret)) return node, ret - def argcount(self, _): + def argcount(self, _) -> int: return 0 def to_variable(self, node): @@ -114,7 +120,7 @@ def to_variable(self, node): self.ctx.vm.trace_unknown(self.class_name, val) return v - def to_structural_def(self, node, class_name): + def to_structural_def(self, node, class_name) -> Class: """Convert this Unknown to a pytd.Class.""" self_param = ( pytd.Parameter( @@ -168,7 +174,7 @@ class Singleton(_base.BaseValue): This is essentially an ABC for Unsolvable, Empty, and others. """ - _instance = None + _instance: Optional["Singleton"] = None def __new__(cls, *args, **kwargs): # If cls is a subclass of a subclass of Singleton, cls._instance will be @@ -182,10 +188,10 @@ def get_special_attribute(self, node, name, valself): del name, valself return self.to_variable(node) - def compute_mro(self): + def compute_mro(self) -> tuple[types.BaseValue, Any]: return self.default_mro() - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, alias_map=None) -> tuple[_T0, Any]: del func, args return node, self.to_variable(node) @@ -219,14 +225,14 @@ def f(): convert.Converter._function_to_def and tracer_vm.CallTracer.pytd_for_types. 
""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: super().__init__("empty", ctx) class Deleted(Empty): """Assigned to variables that have del called on them.""" - def __init__(self, line, ctx): + def __init__(self, line, ctx) -> None: super().__init__(ctx) self.line = line self.name = "deleted" @@ -248,13 +254,13 @@ class Unsolvable(Singleton): only need one. """ - IGNORED_ATTRIBUTES = ["__get__", "__set__", "__getattribute__"] + IGNORED_ATTRIBUTES: list[str] = ["__get__", "__set__", "__getattribute__"] # Since an unsolvable gets generated e.g. for every unresolved import, we # can have multiple circular Unsolvables in a class' MRO. Treat those special. SINGLETON = True - def __init__(self, ctx): + def __init__(self, ctx) -> None: super().__init__("unsolveable", ctx) def get_special_attribute(self, node, name, _): @@ -264,7 +270,7 @@ def get_special_attribute(self, node, name, _): else: return self.to_variable(node) - def argcount(self, _): + def argcount(self, _) -> int: return 0 diff --git a/pytype/abstract/_special_classes.py b/pytype/abstract/_special_classes.py index 989074b5e..fa089cc8b 100644 --- a/pytype/abstract/_special_classes.py +++ b/pytype/abstract/_special_classes.py @@ -68,17 +68,17 @@ def maybe_build_from_mro(self, abstract_cls, name, pytd_cls): class _TypedDictBuilder(_Builder): """Build a typed dict.""" - CLASSES = ("typing.TypedDict", "typing_extensions.TypedDict") + CLASSES: tuple[str, str] = ("typing.TypedDict", "typing_extensions.TypedDict") - def matches_class(self, c): + def matches_class(self, c) -> bool: return c.name in self.CLASSES - def matches_base(self, c): + def matches_base(self, c) -> bool: return any( isinstance(b, pytd.ClassType) and self.matches_class(b) for b in c.bases ) - def matches_mro(self, c): + def matches_mro(self, c) -> bool: # Check if we have typed dicts in the MRO by seeing if we have already # created a TypedDictClass for one of the ancestor classes. 
return any( @@ -96,17 +96,17 @@ def make_derived_class(self, name, pytd_cls): class _NamedTupleBuilder(_Builder): """Build a namedtuple.""" - CLASSES = ("typing.NamedTuple",) + CLASSES: tuple[str] = ("typing.NamedTuple",) - def matches_class(self, c): + def matches_class(self, c) -> bool: return c.name in self.CLASSES - def matches_base(self, c): + def matches_base(self, c) -> bool: return any( isinstance(b, pytd.ClassType) and self.matches_class(b) for b in c.bases ) - def matches_mro(self, c): + def matches_mro(self, c) -> bool: # We only create namedtuples by direct inheritance return False @@ -117,7 +117,10 @@ def make_derived_class(self, name, pytd_cls): return self.convert.make_namedtuple(name, pytd_cls) -_BUILDERS = (_TypedDictBuilder, _NamedTupleBuilder) +_BUILDERS: tuple[type[_TypedDictBuilder], type[_NamedTupleBuilder]] = ( + _TypedDictBuilder, + _NamedTupleBuilder, +) def maybe_build_from_pytd(name, pytd_cls, ctx): diff --git a/pytype/abstract/_typing.py b/pytype/abstract/_typing.py index 13acff670..3a5710fbb 100644 --- a/pytype/abstract/_typing.py +++ b/pytype/abstract/_typing.py @@ -4,6 +4,7 @@ import dataclasses import logging import typing +from typing import TypeVar from pytype import datatypes from pytype.abstract import _base @@ -14,7 +15,12 @@ from pytype.abstract import mixin from pytype.pytd import pytd_utils -log = logging.getLogger(__name__) + +_T0 = TypeVar("_T0") +_TUnion = TypeVar("_TUnion", bound="Union") +_T_TypeVariable = TypeVar("_T_TypeVariable", bound="_TypeVariable") + +log: logging.Logger = logging.getLogger(__name__) def _get_container_type_key(container): @@ -32,14 +38,14 @@ def __init__(self, name, ctx): mixin.HasSlots.init_mixin(self) self.set_native_slot("__getitem__", self.getitem_slot) - def getitem_slot(self, node, slice_var): + def getitem_slot(self, node: _T0, slice_var) -> tuple[_T0, typing.Any]: """Custom __getitem__ implementation.""" slice_content = abstract_utils.maybe_extract_tuple(slice_var) inner, ellipses = 
self._build_inner(slice_content) value = self._build_value(node, tuple(inner), ellipses) return node, value.to_variable(node) - def _build_inner(self, slice_content): + def _build_inner(self, slice_content) -> tuple[list, set[int]]: """Build the list of parameters. Args: @@ -69,7 +75,7 @@ def _build_inner(self, slice_content): def _build_value(self, node, inner, ellipses): raise NotImplementedError(self.__class__.__name__) - def __repr__(self): + def __repr__(self) -> str: return f"AnnotationClass({self.name})" def _get_class(self): @@ -79,11 +85,11 @@ def _get_class(self): class AnnotationContainer(AnnotationClass): """Implementation of X[...] for annotations.""" - def __init__(self, name, ctx, base_cls): + def __init__(self, name, ctx, base_cls) -> None: super().__init__(name, ctx) self.base_cls = base_cls - def __repr__(self): + def __repr__(self) -> str: return f"AnnotationContainer({self.name})" def _sub_annotation( @@ -374,14 +380,16 @@ def _build_value(self, node, inner, ellipses): self.ctx.errorlog.invalid_annotation(self.ctx.vm.frames, e.annot, e.error) return self.ctx.convert.unsolvable - def call(self, node, func, args, alias_map=None): + def call( + self, node, func, args, alias_map=None + ) -> tuple[typing.Any, typing.Any]: return self._call_helper(node, self.base_cls, func, args) class _TypeVariableInstance(_base.BaseValue): """An instance of a type parameter.""" - def __init__(self, param, instance, ctx): + def __init__(self, param, instance, ctx) -> None: super().__init__(param.name, ctx) self.cls = self.param = param self.instance = instance @@ -391,7 +399,9 @@ def __init__(self, param, instance, ctx): def full_name(self): return f"{self.scope}.{self.name}" if self.scope else self.name - def call(self, node, func, args, alias_map=None): + def call( + self, node: _T0, func, args, alias_map=None + ) -> tuple[typing.Any, typing.Any]: var = self.instance.get_instance_type_parameter(self.name) if var.bindings: return function.call_function(self.ctx, 
node, var, args) @@ -403,10 +413,10 @@ def __eq__(self, other): return self.param == other.param and self.instance == other.instance return NotImplemented - def __hash__(self): + def __hash__(self) -> int: return hash((self.param, self.instance)) - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}({self.name!r})" @@ -434,7 +444,7 @@ def __init__( covariant=False, contravariant=False, scope=None, - ): + ) -> None: super().__init__(name, ctx) # TODO(b/217789659): PEP-612 does not mention constraints, but ParamSpecs # ignore all the extra parameters anyway.. @@ -453,7 +463,7 @@ def module(self, module): def full_name(self): return f"{self.scope}.{self.name}" if self.scope else self.name - def is_generic(self): + def is_generic(self) -> bool: return not self.constraints and not self.bound def copy(self): @@ -484,10 +494,10 @@ def __eq__(self, other): ) return NotImplemented - def __ne__(self, other): + def __ne__(self, other) -> bool: return not self == other - def __hash__(self): + def __hash__(self) -> int: return hash(( self.name, self.constraints, @@ -496,7 +506,7 @@ def __hash__(self): self.contravariant, )) - def __repr__(self): + def __repr__(self) -> str: return "{!s}({!r}, constraints={!r}, bound={!r}, module={!r})".format( self.__class__.__name__, self.name, @@ -522,7 +532,7 @@ def instantiate(self, node, container=None): var.AddBinding(self.ctx.convert.unsolvable, [], node) return var - def update_official_name(self, name): + def update_official_name(self, name) -> None: if self.name != name: message = ( f"TypeVar({self.name!r}) must be stored as {self.name!r}, " @@ -530,26 +540,28 @@ def update_official_name(self, name): ) self.ctx.errorlog.invalid_typevar(self.ctx.vm.frames, message) - def call(self, node, func, args, alias_map=None): + def call( + self, node: _T0, func, args, alias_map=None + ) -> tuple[_T0, typing.Any]: return node, self.instantiate(node) class TypeParameter(_TypeVariable): """Parameter of a type 
(typing.TypeVar).""" - _INSTANCE_CLASS = TypeParameterInstance + _INSTANCE_CLASS: type[TypeParameterInstance] = TypeParameterInstance class ParamSpec(_TypeVariable): """Parameter of a callable type (typing.ParamSpec).""" - _INSTANCE_CLASS = ParamSpecInstance + _INSTANCE_CLASS: type[ParamSpecInstance] = ParamSpecInstance class ParamSpecArgs(_base.BaseValue): """ParamSpec.args.""" - def __init__(self, paramspec, ctx): + def __init__(self, paramspec, ctx) -> None: super().__init__(f"{paramspec.name}.args", ctx) self.paramspec = paramspec @@ -560,7 +572,7 @@ def instantiate(self, node, container=None): class ParamSpecKwargs(_base.BaseValue): """ParamSpec.kwargs.""" - def __init__(self, paramspec, ctx): + def __init__(self, paramspec, ctx) -> None: super().__init__(f"{paramspec.name}.kwargs", ctx) self.paramspec = paramspec @@ -571,7 +583,7 @@ def instantiate(self, node, container=None): class Concatenate(_base.BaseValue): """Concatenation of args and ParamSpec.""" - def __init__(self, params, ctx): + def __init__(self, params, ctx) -> None: super().__init__("Concatenate", ctx) self.args = params[:-1] self.paramspec = params[-1] @@ -591,7 +603,7 @@ def get_args(self): # Satisfies the same interface as abstract.CallableClass return self.args - def __repr__(self): + def __repr__(self) -> str: args = ", ".join(list(map(repr, self.args)) + [self.paramspec.name]) return f"Concatenate[{args}]" @@ -605,7 +617,7 @@ class Union(_base.BaseValue, mixin.NestedAnnotation, mixin.HasSlots): options: Iterable of instances of BaseValue. """ - def __init__(self, options, ctx): + def __init__(self, options, ctx) -> None: super().__init__("Union", ctx) assert options self.options = list(options) @@ -616,7 +628,7 @@ def __init__(self, options, ctx): mixin.HasSlots.init_mixin(self) self.set_native_slot("__getitem__", self.getitem_slot) - def __repr__(self): + def __repr__(self) -> str: if self._printing: # recursion detected printed_contents = "..." 
else: @@ -625,15 +637,15 @@ def __repr__(self): self._printing = False return f"{self.name}[{printed_contents}]" - def __eq__(self, other): + def __eq__(self, other) -> bool: if isinstance(other, type(self)): return self.options == other.options return NotImplemented - def __ne__(self, other): + def __ne__(self, other) -> bool: return not self == other - def __hash__(self): + def __hash__(self) -> int: # Use the names of the parameter values to approximate a hash, to avoid # infinite recursion on recursive type annotations. return hash(tuple(o.full_name for o in self.options)) @@ -648,7 +660,7 @@ def _get_class(self): else: return classes.pop() - def getitem_slot(self, node, slice_var): + def getitem_slot(self, node: _T0, slice_var) -> tuple[_T0, typing.Any]: """Custom __getitem__ implementation.""" slice_content = abstract_utils.maybe_extract_tuple(slice_var) params = self.ctx.annotation_utils.get_type_parameters(self) @@ -692,23 +704,25 @@ def instantiate(self, node, container=None): var.PasteVariable(instance, node) return var - def call(self, node, func, args, alias_map=None): + def call( + self, node, func, args, alias_map=None + ) -> tuple[typing.Any, typing.Any]: var = self.ctx.program.NewVariable(self.options, [], node) return function.call_function(self.ctx, node, var, args) - def get_formal_type_parameter(self, t): + def get_formal_type_parameter(self: _TUnion, t) -> _TUnion: new_options = [ option.get_formal_type_parameter(t) for option in self.options ] return Union(new_options, self.ctx) - def get_inner_types(self): + def get_inner_types(self) -> enumerate: return enumerate(self.options) - def update_inner_type(self, key, typ): + def update_inner_type(self, key, typ) -> None: self.options[key] = typ - def replace(self, inner_types): + def replace(self: _TUnion, inner_types) -> _TUnion: return self.__class__((v for _, v in sorted(inner_types)), self.ctx) @@ -726,7 +740,7 @@ class LateAnnotation: Use `x.is_late_annotation()` to check whether x is a late 
annotation. """ - _RESOLVING = object() + _RESOLVING: typing.Any = object() def __init__(self, expr, stack, ctx, *, typing_imports=None): self.expr = expr @@ -784,7 +798,7 @@ def unflatten_expr(self): ) return self.expr - def __repr__(self): + def __repr__(self) -> str: return "LateAnnotation({!r}, resolved={!r})".format( self.expr, self._type if self.resolved else None ) @@ -792,10 +806,10 @@ def __repr__(self): # __hash__ and __eq__ need to be explicitly defined for Python to use them in # set/dict comparisons. - def __hash__(self): + def __hash__(self) -> int: return hash(self._type) if self.resolved else hash(self.expr) - def __eq__(self, other): + def __eq__(self, other) -> bool: return hash(self) == hash(other) def __getattribute__(self, name): @@ -814,7 +828,7 @@ def __setattr__(self, name, value): def __contains__(self, name): return self.resolved and name in self._type - def resolve(self, node, f_globals, f_locals): + def resolve(self, node, f_globals, f_locals) -> None: """Resolve the late annotation.""" if self.resolved: return @@ -851,7 +865,7 @@ def resolve(self, node, f_globals, f_locals): self.resolved = True log.info("Resolved late annotation %r to %r", self.expr, self._type) - def set_type(self, typ): + def set_type(self, typ) -> None: # Used by annotation_utils.sub_one_annotation to substitute values into # recursive aliases. 
assert not self.resolved @@ -882,10 +896,10 @@ def get_special_attribute(self, node, name, valself): return container.get_special_attribute(node, name, valself) return self._type.get_special_attribute(node, name, valself) - def is_late_annotation(self): + def is_late_annotation(self) -> bool: return True - def is_recursive(self): + def is_recursive(self) -> bool: """Check whether this is a recursive type.""" if not self.resolved: return False @@ -909,7 +923,7 @@ def __init__(self, annotation, ctx): super().__init__("FinalAnnotation", ctx) self.annotation = annotation - def __repr__(self): + def __repr__(self) -> str: return f"Final[{self.annotation}]" def instantiate(self, node, container=None): diff --git a/pytype/abstract/abstract.py b/pytype/abstract/abstract.py index d99bdb3fc..4223bada6 100644 --- a/pytype/abstract/abstract.py +++ b/pytype/abstract/abstract.py @@ -18,85 +18,122 @@ from pytype.abstract import class_mixin from pytype.abstract import mixin -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) # For simplicity, we pretend all abstract values are defined in abstract.py. -BaseValue = _base.BaseValue +BaseValue: type[_base.BaseValue] = _base.BaseValue # These are technically mixins, but we use them a lot in isinstance() checks. 
-Class = class_mixin.Class -PythonConstant = mixin.PythonConstant - -BuildClass = _classes.BuildClass -InterpreterClass = _classes.InterpreterClass -PyTDClass = _classes.PyTDClass -FunctionPyTDClass = _classes.FunctionPyTDClass -ParameterizedClass = _classes.ParameterizedClass -CallableClass = _classes.CallableClass -LiteralClass = _classes.LiteralClass -TupleClass = _classes.TupleClass - -Function = _function_base.Function -NativeFunction = _function_base.NativeFunction -BoundFunction = _function_base.BoundFunction -BoundInterpreterFunction = _function_base.BoundInterpreterFunction -BoundPyTDFunction = _function_base.BoundPyTDFunction -ClassMethod = _function_base.ClassMethod -StaticMethod = _function_base.StaticMethod -Property = _function_base.Property -SignedFunction = _function_base.SignedFunction -SimpleFunction = _function_base.SimpleFunction - -SimpleValue = _instance_base.SimpleValue -Instance = _instance_base.Instance - -LazyConcreteDict = _instances.LazyConcreteDict -ConcreteValue = _instances.ConcreteValue -Module = _instances.Module -Coroutine = _instances.Coroutine -Iterator = _instances.Iterator -BaseGenerator = _instances.BaseGenerator -AsyncGenerator = _instances.AsyncGenerator -Generator = _instances.Generator -Tuple = _instances.Tuple -List = _instances.List -Dict = _instances.Dict -AnnotationsDict = _instances.AnnotationsDict -Splat = _instances.Splat -SequenceLength = _instances.SequenceLength - -InterpreterFunction = _interpreter_function.InterpreterFunction - -PyTDFunction = _pytd_function.PyTDFunction -PyTDSignature = _pytd_function.PyTDSignature -SignatureMutationError = _pytd_function.SignatureMutationError - -Unknown = _singletons.Unknown -Singleton = _singletons.Singleton -Empty = _singletons.Empty -Deleted = _singletons.Deleted -Unsolvable = _singletons.Unsolvable -Null = _singletons.Null - -AnnotationClass = _typing.AnnotationClass -AnnotationContainer = _typing.AnnotationContainer -ParamSpec = _typing.ParamSpec -ParamSpecArgs = 
_typing.ParamSpecArgs -ParamSpecKwargs = _typing.ParamSpecKwargs -ParamSpecInstance = _typing.ParamSpecInstance -Concatenate = _typing.Concatenate -TypeParameter = _typing.TypeParameter -TypeParameterInstance = _typing.TypeParameterInstance -Union = _typing.Union -LateAnnotation = _typing.LateAnnotation -FinalAnnotation = _typing.FinalAnnotation - -AMBIGUOUS = (Unknown, Unsolvable) -AMBIGUOUS_OR_EMPTY = AMBIGUOUS + (Empty,) -FUNCTION_TYPES = (BoundFunction, Function) -INTERPRETER_FUNCTION_TYPES = (BoundInterpreterFunction, InterpreterFunction) -PYTD_FUNCTION_TYPES = (BoundPyTDFunction, PyTDFunction) -TYPE_VARIABLE_TYPES = (TypeParameter, ParamSpec) -TYPE_VARIABLE_INSTANCES = (TypeParameterInstance, ParamSpecInstance) +Class: type[class_mixin.Class] = class_mixin.Class +PythonConstant: type[mixin.PythonConstant] = mixin.PythonConstant + +BuildClass: type[_classes.BuildClass] = _classes.BuildClass +InterpreterClass: type[_classes.InterpreterClass] = _classes.InterpreterClass +PyTDClass: type[_classes.PyTDClass] = _classes.PyTDClass +FunctionPyTDClass: type[_classes.FunctionPyTDClass] = _classes.FunctionPyTDClass +ParameterizedClass: type[_classes.ParameterizedClass] = ( + _classes.ParameterizedClass +) +CallableClass: type[_classes.CallableClass] = _classes.CallableClass +LiteralClass: type[_classes.LiteralClass] = _classes.LiteralClass +TupleClass: type[_classes.TupleClass] = _classes.TupleClass + +Function: type[_function_base.Function] = _function_base.Function +NativeFunction: type[_function_base.NativeFunction] = ( + _function_base.NativeFunction +) +BoundFunction: type[_function_base.BoundFunction] = _function_base.BoundFunction +BoundInterpreterFunction: type[_function_base.BoundInterpreterFunction] = ( + _function_base.BoundInterpreterFunction +) +BoundPyTDFunction: type[_function_base.BoundPyTDFunction] = ( + _function_base.BoundPyTDFunction +) +ClassMethod: type[_function_base.ClassMethod] = _function_base.ClassMethod +StaticMethod: 
type[_function_base.StaticMethod] = _function_base.StaticMethod +Property: type[_function_base.Property] = _function_base.Property +SignedFunction: type[_function_base.SignedFunction] = ( + _function_base.SignedFunction +) +SimpleFunction: type[_function_base.SimpleFunction] = ( + _function_base.SimpleFunction +) + +SimpleValue: type[_instance_base.SimpleValue] = _instance_base.SimpleValue +Instance: type[_instance_base.Instance] = _instance_base.Instance + +LazyConcreteDict: type[_instances.LazyConcreteDict] = ( + _instances.LazyConcreteDict +) +ConcreteValue: type[_instances.ConcreteValue] = _instances.ConcreteValue +Module: type[_instances.Module] = _instances.Module +Coroutine: type[_instances.Coroutine] = _instances.Coroutine +Iterator: type[_instances.Iterator] = _instances.Iterator +BaseGenerator: type[_instances.BaseGenerator] = _instances.BaseGenerator +AsyncGenerator: type[_instances.AsyncGenerator] = _instances.AsyncGenerator +Generator: type[_instances.Generator] = _instances.Generator +Tuple: type[_instances.Tuple] = _instances.Tuple +List: type[_instances.List] = _instances.List +Dict: type[_instances.Dict] = _instances.Dict +AnnotationsDict: type[_instances.AnnotationsDict] = _instances.AnnotationsDict +Splat: type[_instances.Splat] = _instances.Splat +SequenceLength: type[_instances.SequenceLength] = _instances.SequenceLength + +InterpreterFunction: type[_interpreter_function.InterpreterFunction] = ( + _interpreter_function.InterpreterFunction +) + +PyTDFunction: type[_pytd_function.PyTDFunction] = _pytd_function.PyTDFunction +PyTDSignature: type[_pytd_function.PyTDSignature] = _pytd_function.PyTDSignature +SignatureMutationError: type[_pytd_function.SignatureMutationError] = ( + _pytd_function.SignatureMutationError +) + +Unknown: type[_singletons.Unknown] = _singletons.Unknown +Singleton: type[_singletons.Singleton] = _singletons.Singleton +Empty: type[_singletons.Empty] = _singletons.Empty +Deleted: type[_singletons.Deleted] = _singletons.Deleted 
+Unsolvable: type[_singletons.Unsolvable] = _singletons.Unsolvable +Null: type[_singletons.Null] = _singletons.Null + +AnnotationClass: type[_typing.AnnotationClass] = _typing.AnnotationClass +AnnotationContainer: type[_typing.AnnotationContainer] = ( + _typing.AnnotationContainer +) +ParamSpec: type[_typing.ParamSpec] = _typing.ParamSpec +ParamSpecArgs: type[_typing.ParamSpecArgs] = _typing.ParamSpecArgs +ParamSpecKwargs: type[_typing.ParamSpecKwargs] = _typing.ParamSpecKwargs +ParamSpecInstance: type[_typing.ParamSpecInstance] = _typing.ParamSpecInstance +Concatenate: type[_typing.Concatenate] = _typing.Concatenate +TypeParameter: type[_typing.TypeParameter] = _typing.TypeParameter +TypeParameterInstance: type[_typing.TypeParameterInstance] = ( + _typing.TypeParameterInstance +) +Union: type[_typing.Union] = _typing.Union +LateAnnotation: type[_typing.LateAnnotation] = _typing.LateAnnotation +FinalAnnotation: type[_typing.FinalAnnotation] = _typing.FinalAnnotation + +AMBIGUOUS: tuple[type[Unknown], type[Unsolvable]] = (Unknown, Unsolvable) +AMBIGUOUS_OR_EMPTY: tuple[type[Unknown], type[Unsolvable], type[Empty]] = ( + AMBIGUOUS + (Empty,) +) +FUNCTION_TYPES: tuple[type[BoundFunction], type[Function]] = ( + BoundFunction, + Function, +) +INTERPRETER_FUNCTION_TYPES: tuple[ + type[BoundInterpreterFunction], type[InterpreterFunction] +] = (BoundInterpreterFunction, InterpreterFunction) +PYTD_FUNCTION_TYPES: tuple[type[BoundPyTDFunction], type[PyTDFunction]] = ( + BoundPyTDFunction, + PyTDFunction, +) +TYPE_VARIABLE_TYPES: tuple[type[TypeParameter], type[ParamSpec]] = ( + TypeParameter, + ParamSpec, +) +TYPE_VARIABLE_INSTANCES: tuple[ + type[TypeParameterInstance], type[ParamSpecInstance] +] = (TypeParameterInstance, ParamSpecInstance) AmbiguousOrEmptyType = Unknown | Unsolvable | Empty diff --git a/pytype/abstract/abstract_utils.py b/pytype/abstract/abstract_utils.py index 196d79ee5..031d0eee0 100644 --- a/pytype/abstract/abstract_utils.py +++ 
b/pytype/abstract/abstract_utils.py @@ -1,12 +1,13 @@ """Utilities for abstract.py.""" import collections -from collections.abc import Collection, Iterable, Mapping, Sequence +from collections.abc import Collection, Generator, Iterable, Mapping, Sequence import dataclasses import logging -from typing import Any +from typing import Any, TypeVar from pytype import datatypes +from pytype.datatypes import AccessTrackingDict from pytype.pyc import opcodes from pytype.pyc import pyc from pytype.pytd import pytd @@ -14,16 +15,21 @@ from pytype.typegraph import cfg from pytype.typegraph import cfg_utils -log = logging.getLogger(__name__) + +_T0 = TypeVar("_T0") +_T1 = TypeVar("_T1") +_TLocal = TypeVar("_TLocal", bound="Local") + +log: logging.Logger = logging.getLogger(__name__) # Type aliases _ArgsDictType = dict[str, cfg.Variable] # We can't import some modules here due to circular deps. -_ContextType = Any # context.Context -_BaseValueType = Any # abstract.BaseValue -_ParameterizedClassType = Any # abstract.ParameterizedClass -_TypeParamType = Any # abstract.TypeParameter +_ContextType: type[Any] = Any # context.Context +_BaseValueType: type[Any] = Any # abstract.BaseValue +_ParameterizedClassType: type[Any] = Any # abstract.ParameterizedClass +_TypeParamType: type[Any] = Any # abstract.TypeParameter # Type parameter names matching the ones in builtins.pytd and typing.pytd. T = "_T" @@ -34,14 +40,14 @@ RET = "_RET" # TODO(rechen): Stop supporting all variants except _HAS_DYNAMIC_ATTRIBUTES. -DYNAMIC_ATTRIBUTE_MARKERS = [ +DYNAMIC_ATTRIBUTE_MARKERS: list[str] = [ "HAS_DYNAMIC_ATTRIBUTES", "_HAS_DYNAMIC_ATTRIBUTES", "has_dynamic_attributes", ] # Names defined on every module/class that should be ignored in most cases. 
-TOP_LEVEL_IGNORE = frozenset({ +TOP_LEVEL_IGNORE: frozenset[str] = frozenset({ "__builtins__", "__doc__", "__file__", @@ -50,7 +56,7 @@ "__name__", "__annotations__", }) -CLASS_LEVEL_IGNORE = frozenset({ +CLASS_LEVEL_IGNORE: frozenset[str] = frozenset({ "__builtins__", "__class__", "__module__", @@ -60,7 +66,7 @@ "__annotations__", }) -TYPE_GUARDS = {"typing.TypeGuard", "typing.TypeIs"} +TYPE_GUARDS: set[str] = {"typing.TypeGuard", "typing.TypeIs"} # A dummy container object for use in instantiating type parameters. @@ -75,7 +81,7 @@ def __init__(self, container): self.container = container -DUMMY_CONTAINER = DummyContainer(None) +DUMMY_CONTAINER: DummyContainer = DummyContainer(None) class ConversionError(ValueError): @@ -97,7 +103,7 @@ def details(self): class GenericTypeError(Exception): """The error for user-defined generic types.""" - def __init__(self, annot, error): + def __init__(self, annot, error) -> None: super().__init__(annot, error) self.annot = annot self.error = error @@ -110,7 +116,7 @@ class ModuleLoadError(Exception): class AsInstance: """Wrapper, used for marking things that we want to convert to an instance.""" - def __init__(self, cls): + def __init__(self, cls) -> None: self.cls = cls @@ -149,7 +155,7 @@ def __init__( self.ctx = ctx @classmethod - def merge(cls, node, op, local1, local2): + def merge(cls: type[_TLocal], node, op, local1, local2) -> _TLocal: """Merges two locals.""" ctx = local1.ctx typ_values = set() @@ -165,7 +171,7 @@ def merge(cls, node, op, local1, local2): orig = local1.orig or local2.orig return cls(node, op, typ, orig, ctx) - def __repr__(self): + def __repr__(self) -> str: return f"Local(typ={self.typ}, orig={self.orig}, final={self.final})" @property @@ -176,7 +182,7 @@ def stack(self): def last_update_op(self): return self._ops[-1] - def update(self, node, op, typ, orig, final=False): + def update(self, node, op, typ, orig, final=False) -> None: """Update this variable's annotation and/or value.""" if op in 
self._ops: return @@ -210,10 +216,10 @@ def get_type(self, node, name): # Callers are expected to alias them like so: # _isinstance = abstract_utils._isinstance # pylint: disable=protected-access -_ISINSTANCE_CACHE = {} +_ISINSTANCE_CACHE: dict[str, Any] = {} -def _isinstance(obj, name_or_names): +def _isinstance(obj, name_or_names) -> bool: """Do an isinstance() call for a class defined in pytype.abstract. Args: @@ -273,7 +279,7 @@ def get_atomic_value(variable, constant_type=None, default=_None()): ) -def match_atomic_value(variable, typ=None): +def match_atomic_value(variable, typ=None) -> bool: try: get_atomic_value(variable, typ) except ConversionError: @@ -301,7 +307,7 @@ def get_atomic_python_constant(variable, constant_type=None): return atomic.ctx.convert.value_to_constant(atomic, constant_type) -def match_atomic_python_constant(variable, typ=None): +def match_atomic_python_constant(variable, typ=None) -> bool: try: get_atomic_python_constant(variable, typ) except ConversionError: @@ -309,7 +315,9 @@ def match_atomic_python_constant(variable, typ=None): return True -def get_views(variables, node): +def get_views( + variables, node +) -> Generator[AccessTrackingDict[None, None], Any, None]: """Get all possible views of the given variables at a particular node. For performance reasons, this method uses node.CanHaveCombination for @@ -426,7 +434,7 @@ def get_mro_bases(bases): return mro_bases -def _merge_type(t0, t1, name, cls): +def _merge_type(t0: _T0, t1: _T1, name, cls) -> _T0 | _T1: """Merge two types. Rules: Type `Any` can match any type, we will return the other type if one @@ -459,7 +467,7 @@ def _merge_type(t0, t1, name, cls): def parse_formal_type_parameters( base, prefix, formal_type_parameters, container=None -): +) -> None: """Parse type parameters from base class. 
Args: @@ -562,7 +570,9 @@ def maybe_extract_tuple(t): return v.pyval -def eval_expr(ctx, node, f_globals, f_locals, expr): +def eval_expr( + ctx, node, f_globals, f_locals, expr +) -> tuple[Any, EvaluationError | None]: """Evaluate an expression with the given node and globals.""" # This is used to resolve type comments and late annotations. # @@ -667,7 +677,7 @@ def is_indefinite_iterable(val: _BaseValueType) -> bool: return False -def is_var_indefinite_iterable(var): +def is_var_indefinite_iterable(var) -> bool: """True if all bindings of var are indefinite sequences.""" return all(is_indefinite_iterable(x) for x in var.data) @@ -689,7 +699,7 @@ def merged_type_parameter(node, var, param): return var.data[0].ctx.join_variables(node, params) -def is_var_splat(var): +def is_var_splat(var) -> bool: if var.data and _isinstance(var.data[0], "Splat"): # A splat should never have more than one binding, since we create and use # it immediately. @@ -800,7 +810,7 @@ def combine_substs( return () -def flatten(value, classes): +def flatten(value, classes) -> bool: """Flatten the contents of value into classes. If value is a Class, it is appended to classes. @@ -844,7 +854,7 @@ def flatten(value, classes): return True -def check_against_mro(ctx, target, class_spec): +def check_against_mro(ctx, target, class_spec) -> bool | None: """Check if any of the classes are in the target's MRO. 
Args: diff --git a/pytype/abstract/class_mixin.py b/pytype/abstract/class_mixin.py index 8ac12f8b1..303a05567 100644 --- a/pytype/abstract/class_mixin.py +++ b/pytype/abstract/class_mixin.py @@ -3,7 +3,7 @@ from collections.abc import Mapping, Sequence import dataclasses import logging -from typing import Any +from typing import Any, TypeVar from pytype import datatypes from pytype.abstract import abstract_utils @@ -14,18 +14,23 @@ from pytype.typegraph import cfg -log = logging.getLogger(__name__) +_T0 = TypeVar("_T0") +_TAttribute = TypeVar("_TAttribute", bound="Attribute") +_TClass = TypeVar("_TClass", bound="Class") + + +log: logging.Logger = logging.getLogger(__name__) _isinstance = abstract_utils._isinstance # pylint: disable=protected-access _make = abstract_utils._make # pylint: disable=protected-access -_InterpreterFunction = Any # can't import due to a circular dependency +_InterpreterFunction: Any = Any # can't import due to a circular dependency FunctionMapType = Mapping[str, Sequence[_InterpreterFunction]] # Classes have a metadata dictionary that can store arbitrary metadata for # various overlays. We define the dictionary keys here so that they can be # shared by abstract.py and the overlays. -_METADATA_KEYS = { +_METADATA_KEYS: dict[str, str] = { "dataclasses.dataclass": "__dataclass_fields__", # attr.s gets resolved to attr._make.attrs in pyi files but intercepted by # the attr overlay as attr.s when processing bytecode. 
@@ -49,7 +54,7 @@ } -def get_metadata_key(decorator): +def get_metadata_key(decorator) -> str | None: return _METADATA_KEYS.get(decorator) @@ -83,7 +88,9 @@ class Attribute: pytd_const: Any = None @classmethod - def from_pytd_constant(cls, const, ctx, *, kw_only=False): + def from_pytd_constant( + cls: type[_TAttribute], const, ctx, *, kw_only=False + ) -> _TAttribute: """Generate an Attribute from a pytd.Constant.""" typ = ctx.convert.constant_to_value(const.type) # We want to generate the default from the type, not from the value @@ -113,7 +120,7 @@ def to_pytd_constant(self): # will have been created from a parent PyTDClass. return self.pytd_const - def __repr__(self): + def __repr__(self) -> str: return str({ "name": self.name, "typ": self.typ, @@ -162,7 +169,7 @@ def f(self): class Class(metaclass=mixin.MixinMeta): # pylint: disable=undefined-variable """Mix-in to mark all class-like values.""" - overloads = ( + overloads: tuple[str, str, str, str, str, str] = ( "_get_class", "call", "compute_mro", @@ -171,12 +178,12 @@ class Class(metaclass=mixin.MixinMeta): # pylint: disable=undefined-variable "update_official_name", ) - def __new__(cls, *unused_args, **unused_kwds): + def __new__(cls: type[_TClass], *unused_args, **unused_kwds) -> _TClass: """Prevent direct instantiation.""" assert cls is not Class, "Cannot instantiate Class" return object.__new__(cls) - def init_mixin(self, metaclass): + def init_mixin(self, metaclass) -> None: """Mix-in equivalent of __init__.""" if metaclass is None: metaclass = self._get_inherited_metaclass() @@ -199,7 +206,7 @@ def init_mixin(self, metaclass): def _get_class(self): return self.ctx.convert.type_type - def bases(self): + def bases(self) -> list: return [] @property @@ -207,7 +214,7 @@ def all_formal_type_parameters(self): self._load_all_formal_type_parameters() return self._all_formal_type_parameters - def _load_all_formal_type_parameters(self): + def _load_all_formal_type_parameters(self) -> None: """Load 
_all_formal_type_parameters.""" if self._all_formal_type_parameters_loaded: return @@ -229,14 +236,14 @@ def get_own_attributes(self): """Get the attributes defined by this class.""" raise NotImplementedError(self.__class__.__name__) - def has_protocol_base(self): + def has_protocol_base(self) -> bool: """Returns whether this class inherits directly from typing.Protocol. Subclasses that may inherit from Protocol should override this method. """ return False - def _init_protocol_attributes(self): + def _init_protocol_attributes(self) -> None: """Compute this class's protocol attributes.""" if _isinstance(self, "ParameterizedClass"): self.protocol_attributes = self.base_cls.protocol_attributes @@ -283,7 +290,7 @@ def _init_protocol_attributes(self): protocol_attributes = {a for a in protocol_attributes if a not in cls} self.protocol_attributes = protocol_attributes - def _init_overrides_bool(self): + def _init_overrides_bool(self) -> None: """Compute and cache whether the class sets its own boolean value.""" # A class's instances can evaluate to False if it defines __bool__ or # __len__. @@ -301,7 +308,7 @@ def get_own_abstract_methods(self): """Get the abstract methods defined by this class.""" raise NotImplementedError(self.__class__.__name__) - def _init_abstract_methods(self): + def _init_abstract_methods(self) -> None: """Compute this class's abstract methods.""" # For the algorithm to run, abstract_methods needs to be populated with the # abstract methods defined by this class. 
We'll overwrite the attribute @@ -321,10 +328,10 @@ def _init_abstract_methods(self): abstract_methods |= {m for m in cls.abstract_methods if m in cls} self.abstract_methods = abstract_methods - def _has_explicit_abcmeta(self): + def _has_explicit_abcmeta(self) -> bool: return any(base.full_name == "abc.ABCMeta" for base in self.cls.mro) - def _has_implicit_abcmeta(self): + def _has_implicit_abcmeta(self) -> bool: """Whether the class should be considered implicitly abstract.""" # Protocols must be marked as abstract to get around the # [ignored-abstractmethod] check for interpreter classes. @@ -350,7 +357,7 @@ def is_abstract(self): self._has_explicit_abcmeta() or self._has_implicit_abcmeta() ) and bool(self.abstract_methods) - def is_test_class(self): + def is_test_class(self) -> bool: return any( base.full_name in ("unittest.TestCase", "unittest.case.TestCase") for base in self.mro @@ -424,7 +431,7 @@ def call_init_subclass(self, node): node = cls.init_subclass(node, self) return node - def get_own_new(self, node, value): + def get_own_new(self, node, value) -> tuple[Any, Any]: """Get this value's __new__ method, if it isn't object.__new__. Args: @@ -450,7 +457,7 @@ def get_own_new(self, node, value): return node, None return node, new - def _call_new_and_init(self, node, value, args): + def _call_new_and_init(self, node, value, args) -> tuple[Any, Any]: """Call __new__ if it has been overridden on the given value.""" node, new = self.get_own_new(node, value) if new is None: @@ -493,7 +500,7 @@ def _new_instance(self, container, node, args): self._instance_cache[key] = _make("Instance", self, self.ctx, container) return self._instance_cache[key] - def _check_not_instantiable(self): + def _check_not_instantiable(self) -> None: """Report [not-instantiable] if the class cannot be instantiated.""" # We report a not-instantiable error if all of the following are true: # - The class is abstract. 
@@ -512,7 +519,7 @@ def _check_not_instantiable(self): return self.ctx.errorlog.not_instantiable(self.ctx.vm.frames, self) - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple[Any, Any]: del alias_map # unused self._check_not_instantiable() node, variable = self._call_new_and_init(node, func, args) @@ -549,16 +556,16 @@ def get_special_attribute(self, node, name, valself): return container.get_special_attribute(node, name, valself) return Class.super(self.get_special_attribute)(node, name, valself) - def has_dynamic_attributes(self): + def has_dynamic_attributes(self) -> bool: return any(a in self for a in abstract_utils.DYNAMIC_ATTRIBUTE_MARKERS) - def compute_is_dynamic(self): + def compute_is_dynamic(self) -> bool: # This needs to be called after self.mro is set. return any( c.has_dynamic_attributes() for c in self.mro if isinstance(c, Class) ) - def compute_mro(self): + def compute_mro(self) -> tuple: """Compute the class precedence list (mro) according to C3.""" bases = abstract_utils.get_mro_bases(self.bases()) bases = [[self]] + [list(base.mro) for base in bases] + [list(bases)] @@ -578,7 +585,7 @@ def compute_mro(self): # calc MRO and replace them with original base classes return tuple(base2cls[base] for base in mro.MROMerge(newbases)) - def _get_mro_attrs_for_attrs(self, cls_attrs, metadata_key): + def _get_mro_attrs_for_attrs(self, cls_attrs, metadata_key) -> list: """Traverse the MRO and collect base class attributes for metadata_key.""" # For dataclasses, attributes preserve the ordering from the reversed MRO, # but derived classes can override the type of an attribute. 
For attrs, @@ -603,7 +610,7 @@ def _get_mro_attrs_for_attrs(self, cls_attrs, metadata_key): base_attrs.append(a) return base_attrs + cls_attrs - def _recompute_attrs_type_from_mro(self, all_attrs, type_params): + def _recompute_attrs_type_from_mro(self, all_attrs, type_params) -> None: """Traverse the MRO and apply Generic type params to class attributes. This IS REQUIRED for dataclass instances that inherits from a Generic. @@ -655,7 +662,7 @@ def _get_attrs_from_mro(self, cls_attrs, metadata_key): self._recompute_attrs_type_from_mro(all_attrs, type_params) return list(all_attrs.values()) - def record_attr_ordering(self, own_attrs): + def record_attr_ordering(self, own_attrs) -> None: """Records the order of attrs to write in the output pyi.""" self.metadata["attr_order"] = own_attrs @@ -692,7 +699,7 @@ def update_official_name(self, name: str) -> None: if isinstance(member, Class): member.update_official_name(f"{name}.{member.name}") - def _convert_str_tuple(self, field_name): + def _convert_str_tuple(self, field_name) -> tuple | None: """Convert __slots__ and similar fields from a Variable to a tuple.""" field_var = self.members.get(field_name) if field_var is None: @@ -721,7 +728,7 @@ def _convert_str_tuple(self, field_name): return None return tuple(self._mangle(s) for s in names) - def _mangle(self, name): + def _mangle(self, name) -> str: """Do name-mangling on an attribute name. See https://goo.gl/X85fHt. 
Python automatically converts a name like diff --git a/pytype/abstract/function.py b/pytype/abstract/function.py index 3c9263c97..29ebe9943 100644 --- a/pytype/abstract/function.py +++ b/pytype/abstract/function.py @@ -2,6 +2,7 @@ import abc import collections +from collections.abc import Generator import dataclasses import itertools import logging @@ -19,12 +20,17 @@ from pytype.typegraph import cfg_utils from pytype.types import types -log = logging.getLogger(__name__) +_T0 = TypeVar("_T0") +_T3 = TypeVar("_T3") +_TArgs = TypeVar("_TArgs", bound="Args") +_TSignature = TypeVar("_TSignature", bound="Signature") + +log: logging.Logger = logging.getLogger(__name__) _isinstance = abstract_utils._isinstance # pylint: disable=protected-access _make = abstract_utils._make # pylint: disable=protected-access -def argname(i): +def argname(i) -> str: """Get a name for an unnamed positional argument, given its position.""" return "_" + str(i) @@ -145,10 +151,10 @@ def has_return_annotation(self): def has_param_annotations(self): return bool(self.annotations.keys() - {"return"}) - def has_default(self, name): + def has_default(self, name) -> bool: return name in self.defaults - def add_scope(self, cls): + def add_scope(self, cls) -> None: """Add scope for type parameters in annotations.""" annotations = {} for key, val in self.annotations.items(): @@ -176,13 +182,13 @@ def _postprocess_annotation(self, name, annotation): else: return annotation - def set_annotation(self, name, annotation): + def set_annotation(self, name, annotation) -> None: self.annotations[name] = self._postprocess_annotation(name, annotation) - def del_annotation(self, name): + def del_annotation(self, name) -> None: del self.annotations[name] # Raises KeyError if annotation does not exist. 
- def check_type_parameters(self, stack, opcode, is_attribute_of_class): + def check_type_parameters(self, stack, opcode, is_attribute_of_class) -> None: """Check type parameters in function.""" if not self.annotations: return @@ -291,20 +297,20 @@ def prepend_parameter(self: _SigT, name: str, typ: _base.BaseValue) -> _SigT: annots = {**self.annotations, name: typ} return self._replace(param_names=param_names, annotations=annots) - def mandatory_param_count(self): + def mandatory_param_count(self) -> int: num = len([name for name in self.param_names if name not in self.defaults]) num += len( [name for name in self.kwonly_params if name not in self.defaults] ) return num - def maximum_param_count(self): + def maximum_param_count(self) -> int | None: if self.varargs_name or self.kwargs_name: return None return len(self.param_names) + len(self.kwonly_params) @classmethod - def from_pytd(cls, ctx, name, sig): + def from_pytd(cls: type[_TSignature], ctx, name, sig) -> _TSignature: """Construct an abstract signature from a pytd signature.""" pytd_annotations = [ (p.name, p.type) @@ -345,7 +351,7 @@ def param_to_var(p): ) @classmethod - def from_callable(cls, val): + def from_callable(cls: type[_TSignature], val) -> _TSignature: annotations = { argname(i): val.formal_type_parameters[i] for i in range(val.num_args) } @@ -363,7 +369,9 @@ def from_callable(cls, val): ) @classmethod - def from_param_names(cls, name, param_names, kind=pytd.ParameterKind.REGULAR): + def from_param_names( + cls: type[_TSignature], name, param_names, kind=pytd.ParameterKind.REGULAR + ) -> _TSignature: """Construct a minimal signature from a name and a list of param names.""" names = tuple(param_names) if kind == pytd.ParameterKind.REGULAR: @@ -391,7 +399,7 @@ def from_param_names(cls, name, param_names, kind=pytd.ParameterKind.REGULAR): ) @classmethod - def from_any(cls): + def from_any(cls: type[_TSignature]) -> _TSignature: """Treat `Any` as `f(...) 
-> Any`.""" return cls( name="", @@ -436,11 +444,11 @@ def insert_varargs_and_kwargs(self, args): ) return self._replace(param_names=new_param_names) - _ATTRIBUTES = set( + _ATTRIBUTES: set = set( __init__.__code__.co_varnames[: __init__.__code__.co_argcount] ) - {"self", "postprocess_annotations"} - def _replace(self, **kwargs): + def _replace(self: _TSignature, **kwargs) -> _TSignature: """Returns a copy of the signature with the specified values replaced.""" assert not set(kwargs) - self._ATTRIBUTES for attr in self._ATTRIBUTES: @@ -449,7 +457,7 @@ def _replace(self, **kwargs): kwargs["postprocess_annotations"] = False return type(self)(**kwargs) - def iter_args(self, args): + def iter_args(self, args) -> Generator[tuple[Any, Any, Any], Any, None]: """Iterates through the given args, attaching names and expected types.""" for i, posarg in enumerate(args.posargs): if i < len(self.param_names): @@ -485,7 +493,7 @@ def iter_args(self, args): self.annotations.get(self.kwargs_name), ) - def check_defaults(self, ctx): + def check_defaults(self, ctx) -> None: """Raises an error if a non-default param follows a default.""" has_default = False for name in self.param_names: @@ -499,7 +507,7 @@ def check_defaults(self, ctx): ctx.errorlog.invalid_function_definition(ctx.vm.stack(), msg) return - def _yield_arguments(self): + def _yield_arguments(self) -> Generator[Any, Any, None]: """Yield all the function arguments.""" names = list(self.param_names) if self.varargs_name: @@ -520,7 +528,7 @@ def _yield_arguments(self): def _print_annot(self, name): return _print(self.annotations[name]) if name in self.annotations else None - def _print_default(self, name): + def _print_default(self, name) -> str | None: if name in self.defaults: values = self.defaults[name].data if len(values) > 1: @@ -530,7 +538,7 @@ def _print_default(self, name): else: return None - def __repr__(self): + def __repr__(self) -> str: args = list(self._yield_arguments()) if self.posonly_count: args = args[: 
self.posonly_count] + ["/"] + args[self.posonly_count :] @@ -557,7 +565,7 @@ def get_first_arg(self, callargs): return None return callargs.get(name) - def populate_annotation_dict(self, annots, ctx, param_names=None): + def populate_annotation_dict(self, annots, ctx, param_names=None) -> None: """Populate annotation dict with default values.""" if param_names is None: param_names = self.param_names @@ -570,7 +578,7 @@ def populate_annotation_dict(self, annots, ctx, param_names=None): annots[self.kwargs_name] = ctx.convert.dict_type -def _convert_namedargs(namedargs): +def _convert_namedargs(namedargs: _T0) -> dict[None, None] | _T0: return {} if namedargs is None else namedargs @@ -593,13 +601,13 @@ class Args: starargs: cfg.Variable | None = None starstarargs: cfg.Variable | None = None - def has_namedargs(self): + def has_namedargs(self) -> bool: return bool(self.namedargs) - def has_non_namedargs(self): + def has_non_namedargs(self) -> bool: return bool(self.posargs or self.starargs or self.starstarargs) - def is_empty(self): + def is_empty(self) -> bool: return not (self.has_namedargs() or self.has_non_namedargs()) def starargs_as_tuple(self, node, ctx): @@ -636,7 +644,9 @@ def _expand_typed_star(self, node, star, count, ctx): p = ctx.new_unsolvable(node) return [p.AssignToNewVariable(node) for _ in range(count)] - def _unpack_and_match_args(self, node, ctx, match_signature, starargs_tuple): + def _unpack_and_match_args( + self, node, ctx, match_signature, starargs_tuple + ) -> tuple[Any, Any]: """Match args against a signature with unpacking.""" posargs = self.posargs namedargs = self.namedargs @@ -711,7 +721,7 @@ def _unpack_and_match_args(self, node, ctx, match_signature, starargs_tuple): # We have **kwargs but no *args in the invocation return posargs + tuple(pre), None - def simplify(self, node, ctx, match_signature=None): + def simplify(self: _TArgs, node, ctx, match_signature=None) -> _TArgs: """Try to insert part of *args, **kwargs into posargs / 
namedargs.""" # TODO(rechen): When we have type information about *args/**kwargs, # we need to check it before doing this simplification. @@ -780,7 +790,7 @@ def simplify(self, node, ctx, match_signature=None): simplify(starstarargs), ) - def get_variables(self): + def get_variables(self) -> list: variables = list(self.posargs) + list(self.namedargs.values()) if self.starargs is not None: variables.append(self.starargs) @@ -801,10 +811,10 @@ def delete_namedarg(self, name): new_namedargs = {k: v for k, v in self.namedargs.items() if k != name} return self.replace(namedargs=new_namedargs) - def replace(self, **kwargs): + def replace(self: _TArgs, **kwargs) -> _TArgs: return attrs.evolve(self, **kwargs) - def has_opaque_starargs_or_starstarargs(self): + def has_opaque_starargs_or_starstarargs(self) -> bool: return any( arg and not _isinstance(arg, "PythonConstant") for arg in (self.starargs, self.starstarargs) @@ -814,7 +824,7 @@ def has_opaque_starargs_or_starstarargs(self): class ParamSpecMatch(_base.BaseValue): """Match a paramspec against a sig.""" - def __init__(self, paramspec, sig, ctx): + def __init__(self, paramspec, sig, ctx) -> None: super().__init__("ParamSpecMatch", ctx) self.paramspec = paramspec self.sig = sig @@ -837,14 +847,14 @@ class Mutation: name: str value: cfg.Variable - def __eq__(self, other): + def __eq__(self, other) -> bool: return ( self.instance == other.instance and self.name == other.name and frozenset(self.value.data) == frozenset(other.value.data) ) - def __hash__(self): + def __hash__(self) -> int: return hash((self.instance, self.name, frozenset(self.value.data))) @@ -868,7 +878,7 @@ def get_parameter(self, node, param_name): class AbstractReturnType(_ReturnType): """An abstract return type.""" - def __init__(self, t, ctx): + def __init__(self, t, ctx) -> None: self._type = t self._ctx = ctx @@ -887,7 +897,7 @@ def get_parameter(self, node, param_name): class PyTDReturnType(_ReturnType): """A PyTD return type.""" - def 
__init__(self, t, subst, sources, ctx): + def __init__(self, t, subst, sources, ctx) -> None: self._type = t self._subst = subst self._sources = sources @@ -902,7 +912,7 @@ def instantiate_parameter(self, node, param_name): instance = abstract_utils.get_atomic_value(instance_var) return instance.get_instance_type_parameter(param_name) - def instantiate(self, node): + def instantiate(self, node) -> tuple[Any, Any]: """Instantiate the pytd return type.""" # Type parameter values, which are instantiated by the matcher, will end up # in the return value. Since the matcher does not call __init__, we need to @@ -939,7 +949,7 @@ def get_parameter(self, node, param_name): return t.get_formal_type_parameter(param_name) -def _splats_to_any(seq, ctx): +def _splats_to_any(seq, ctx) -> tuple: return tuple( ctx.new_unsolvable(ctx.root_node) if abstract_utils.is_var_splat(v) else v for v in seq @@ -954,7 +964,7 @@ def call_function( fallback_to_unsolvable=True, allow_never=False, strict_filter=True, -): +) -> tuple[Any, Any]: """Call a function. Args: @@ -1050,7 +1060,7 @@ def call_function( raise error # pylint: disable=raising-bad-type -def match_all_args(ctx, node, func, args): +def match_all_args(ctx, node, func, args: _T3) -> tuple[_T3, list[None]]: """Call match_args multiple times to find all type errors. Args: @@ -1111,7 +1121,7 @@ def match_all_args(ctx, node, func, args): return args, errors -def has_visible_namedarg(node, args, names): +def has_visible_namedarg(node, args, names) -> bool: # Note: this method should be called judiciously, as HasCombination is # potentially very expensive. 
namedargs = {args.namedargs[name] for name in names} diff --git a/pytype/abstract/mixin.py b/pytype/abstract/mixin.py index 52611eaa8..64e8d6aaa 100644 --- a/pytype/abstract/mixin.py +++ b/pytype/abstract/mixin.py @@ -10,7 +10,7 @@ from pytype.typegraph import cfg from pytype.types import types -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) _isinstance = abstract_utils._isinstance # pylint: disable=protected-access _make = abstract_utils._make # pylint: disable=protected-access @@ -21,7 +21,7 @@ class MixinMeta(type): __mixin_overloads__: dict[str, type[Any]] _HAS_DYNAMIC_ATTRIBUTES = True - def __init__(cls, name, superclasses, *args, **kwargs): + def __init__(cls: "MixinMeta", name, superclasses, *args, **kwargs) -> None: super().__init__(name, superclasses, *args, **kwargs) for sup in superclasses: if "overloads" in sup.__dict__: @@ -35,7 +35,7 @@ def __init__(cls, name, superclasses, *args, **kwargs): else: setattr(cls, "__mixin_overloads__", {method: sup}) - def super(cls, method): + def super(cls: "MixinMeta", method): """Imitate super() in a mix-in. This method is a substitute for @@ -78,15 +78,15 @@ class PythonConstant(types.PythonConstant, metaclass=MixinMeta): "r" etc.). """ - overloads = ("__repr__",) + overloads: tuple[str] = ("__repr__",) - def init_mixin(self, pyval): + def init_mixin(self, pyval) -> None: """Mix-in equivalent of __init__.""" self.pyval = pyval self.is_concrete = True self._printing = False - def str_of_constant(self, printer): + def str_of_constant(self, printer) -> str: """Get a string representation of this constant. Args: @@ -99,7 +99,7 @@ def str_of_constant(self, printer): del printer return repr(self.pyval) - def __repr__(self): + def __repr__(self) -> str: if self._printing: # recursion detected const = "[...]" else: @@ -116,13 +116,13 @@ class HasSlots(metaclass=MixinMeta): handling of some magic methods (__setitem__ etc.) 
""" - overloads = ("get_special_attribute",) + overloads: tuple[str] = ("get_special_attribute",) - def init_mixin(self): + def init_mixin(self) -> None: self._slots = {} self._super = {} - def set_slot(self, name, slot): + def set_slot(self, name, slot) -> None: """Add a new slot to this value.""" assert name not in self._slots, f"slot {name} already occupied" # For getting a slot value, we don't need a ParameterizedClass's type @@ -135,11 +135,11 @@ def set_slot(self, name, slot): self._super[name] = attr self._slots[name] = slot - def set_native_slot(self, name, method): + def set_native_slot(self, name, method) -> None: """Add a new NativeFunction slot to this value.""" self.set_slot(name, _make("NativeFunction", name, method, self.ctx)) - def call_pytd(self, node, name, *args): + def call_pytd(self, node, name, *args) -> tuple[Any, Any]: """Call the (original) pytd version of a method we overwrote.""" return function.call_function( self.ctx, @@ -175,9 +175,9 @@ class NestedAnnotation(metaclass=MixinMeta): one but with the given inner types, again as a (key, typ) sequence. """ - overloads = ("formal",) + overloads: tuple[str] = ("formal",) - def init_mixin(self): + def init_mixin(self) -> None: self.processed = False self._seen_for_formal = False # for calculating the 'formal' property self._formal = None @@ -228,7 +228,7 @@ class LazyMembers(metaclass=MixinMeta): members: dict[str, cfg.Variable] - def init_mixin(self, member_map): + def init_mixin(self, member_map) -> None: self._member_map = member_map def _convert_member(self, name, member, subst=None): @@ -266,15 +266,18 @@ class PythonDict(PythonConstant): # More methods can be implemented by adding the name to `overloads` and # defining the delegating method. 
- overloads = PythonConstant.overloads + ( - "__getitem__", - "get", - "__contains__", - "copy", - "__iter__", - "items", - "keys", - "values", + overloads: tuple[str, str, str, str, str, str, str, str, str] = ( + PythonConstant.overloads + + ( + "__getitem__", + "get", + "__contains__", + "copy", + "__iter__", + "items", + "keys", + "values", + ) ) def __getitem__(self, key): @@ -283,7 +286,7 @@ def __getitem__(self, key): def get(self, key, default=None): return self.pyval.get(key, default) - def __contains__(self, key): + def __contains__(self, key) -> bool: return key in self.pyval def copy(self): diff --git a/pytype/analyze.py b/pytype/analyze.py index 3fa2122bc..094ca0cef 100644 --- a/pytype/analyze.py +++ b/pytype/analyze.py @@ -12,7 +12,7 @@ from pytype.pytd import pytd_utils from pytype.pytd import visitors -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) # How deep to follow call chains: INIT_MAXIMUM_DEPTH = 4 # during module loading @@ -118,7 +118,7 @@ def infer_types( return Analysis(ctx, ast, deps_pytd) -def _maybe_output_debug(options, program): +def _maybe_output_debug(options, program) -> None: """Maybe emit debugging output.""" if options.output_debug: text = debug.program_to_text(program) diff --git a/pytype/annotation_utils.py b/pytype/annotation_utils.py index cb9ed98c7..e7f986b84 100644 --- a/pytype/annotation_utils.py +++ b/pytype/annotation_utils.py @@ -1,10 +1,9 @@ """Utilities for inline type annotations.""" import collections -from collections.abc import Sequence import dataclasses import itertools -from typing import Any +from typing import Generator, Sequence, TypeVar, Any from pytype import state from pytype import utils @@ -16,6 +15,8 @@ from pytype.pytd import pytd_utils from pytype.typegraph import cfg +_T0 = TypeVar("_T0") + @dataclasses.dataclass class AnnotatedValue: @@ -191,7 +192,7 @@ def get_type_parameter_subst( for name, annot in annotations.items() } - def get_late_annotations(self, 
annot): + def get_late_annotations(self, annot: _T0) -> Generator[_T0, Any, None]: if annot.is_late_annotation() and not annot.resolved: yield annot elif isinstance(annot, mixin.NestedAnnotation): @@ -323,7 +324,9 @@ def convert_function_annotations(self, node, raw_annotations): else: return {} - def convert_annotations_list(self, node, annotations_list): + def convert_annotations_list( + self, node, annotations_list + ) -> dict[Any, abstract.BaseValue]: """Convert a (name, raw_annot) list to a {name: annotation} dict.""" annotations = {} for name, t in annotations_list: @@ -337,7 +340,7 @@ def convert_annotations_list(self, node, annotations_list): annotations[name] = annot return annotations - def convert_class_annotations(self, node, raw_annotations): + def convert_class_annotations(self, node, raw_annotations) -> dict: """Convert a name -> raw_annot dict to annotations.""" annotations = {} raw_items = raw_annotations.items() @@ -350,7 +353,9 @@ def convert_class_annotations(self, node, raw_annotations): annotations[name] = annot or self.ctx.convert.unsolvable return annotations - def init_annotation(self, node, name, annot, container=None, extra_key=None): + def init_annotation( + self, node: _T0, name, annot, container=None, extra_key=None + ) -> tuple[_T0, Any]: value = self.ctx.vm.init_class( node, annot, container=container, extra_key=extra_key ) @@ -404,7 +409,7 @@ def extract_and_init_annotation(self, node, name, var): return typ, self.ctx.new_unsolvable(node) return self._sub_and_instantiate(node, name, typ, substs) - def _sub_and_instantiate(self, node, name, typ, substs): + def _sub_and_instantiate(self, node, name, typ, substs) -> tuple[Any, Any]: if isinstance(typ, abstract.FinalAnnotation): t, value = self._sub_and_instantiate(node, name, typ.annotation, substs) return abstract.FinalAnnotation(t, self.ctx), value @@ -426,7 +431,7 @@ def _sub_and_instantiate(self, node, name, typ, substs): _, value = self.init_annotation(node, name, type_for_value) 
return substituted_type, value - def apply_annotation(self, node, op, name, value): + def apply_annotation(self, node, op, name, value) -> AnnotatedValue: """If there is an annotation for the op, return its value.""" assert op is self.ctx.vm.frame.current_opcode if op.code.filename != self.ctx.vm.filename: @@ -506,7 +511,7 @@ def extract_annotation( return self.ctx.convert.unsolvable return typ - def _log_illegal_params(self, illegal_params, stack, typ, name): + def _log_illegal_params(self, illegal_params, stack, typ, name) -> None: out_of_scope_params = utils.unique_list(illegal_params) details = "TypeVar(s) %s not in scope" % ", ".join( repr(p) for p in out_of_scope_params @@ -525,7 +530,7 @@ def _log_illegal_params(self, illegal_params, stack, typ, name): details += f"\nNote: For all string types, use {str_type}." self.ctx.errorlog.invalid_annotation(stack, typ, details, name) - def eval_multi_arg_annotation(self, node, func, annot, stack): + def eval_multi_arg_annotation(self, node, func, annot, stack) -> None: """Evaluate annotation for multiple arguments (from a type comment).""" args, errorlog = self._eval_expr_as_tuple(node, annot, stack) if errorlog: @@ -571,7 +576,7 @@ def _process_one_annotation( ) -> abstract.BaseValue | None: """Change annotation / record errors where required.""" if isinstance(annotation, abstract.AnnotationContainer): - annotation = annotation.base_cls + annotation = annotation.base_cls # pytype: disable=attribute-error if isinstance(annotation, typing_overlay.Union): self.ctx.errorlog.invalid_annotation( @@ -672,7 +677,9 @@ def _process_one_annotation( ) return None - def _eval_expr_as_tuple(self, node, expr, stack): + def _eval_expr_as_tuple( + self, node, expr, stack + ) -> tuple[tuple, abstract_utils.EvaluationError | None]: """Evaluate an expression as a tuple.""" if not expr: return (), None diff --git a/pytype/attribute.py b/pytype/attribute.py index 041b3a4f2..1f2474b87 100644 --- a/pytype/attribute.py +++ 
b/pytype/attribute.py @@ -1,7 +1,7 @@ """Abstract attribute handling.""" import logging -from typing import Optional +from typing import Any, TypeVar, Optional from pytype import datatypes from pytype import utils @@ -14,7 +14,9 @@ from pytype.overlays import special_builtins from pytype.typegraph import cfg -log = logging.getLogger(__name__) +_T0 = TypeVar("_T0") + +log: logging.Logger = logging.getLogger(__name__) _NodeAndMaybeVarType = tuple[cfg.CFGNode, Optional[cfg.Variable]] @@ -178,23 +180,23 @@ def set_attribute( return node elif isinstance(obj, abstract.TypeParameterInstance): nodes = [] - for v in obj.instance.get_instance_type_parameter(obj.name).data: + for v in obj.instance.get_instance_type_parameter(obj.name).data: # pytype: disable=attribute-error nodes.append(self.set_attribute(node, v, name, value)) return self.ctx.join_cfg_nodes(nodes) if nodes else node elif isinstance(obj, abstract.Union): - for option in obj.options: + for option in obj.options: # pytype: disable=attribute-error node = self.set_attribute(node, option, name, value) return node else: raise NotImplementedError(obj.__class__.__name__) - def _check_writable(self, obj, name): + def _check_writable(self, obj, name) -> bool: """Verify that a given attribute is writable. Log an error if not.""" if not obj.cls.mro: # "Any" etc. return True for baseclass in obj.cls.mro: - if baseclass.full_name == "builtins.object": + if baseclass.full_name == "builtins.object": # pytype: disable=attribute-error # It's not possible to set an attribute on object itself. # (object has __setattr__, but that honors __slots__.) continue @@ -202,7 +204,7 @@ def _check_writable(self, obj, name): "__setattr__" in baseclass or name in baseclass ): return True # This is a programmatic attribute. 
- if baseclass.slots is None or name in baseclass.slots: + if baseclass.slots is None or name in baseclass.slots: # pytype: disable=attribute-error return True # Found a slot declaration; this is an instance attribute self.ctx.errorlog.not_writable(self.ctx.vm.frames, obj, name) return False @@ -300,7 +302,7 @@ def _get_instance_attribute( ) return node, self.ctx.new_unsolvable(node) - def _get_attribute(self, node, obj, cls, name, valself): + def _get_attribute(self, node, obj, cls, name, valself) -> tuple[Any, Any]: """Get an attribute from an object or its class. The underlying method called by all of the (_)get_(x_)attribute methods. @@ -414,7 +416,7 @@ def _get_attribute_from_super_instance( def _lookup_from_mro_and_handle_descriptors( self, node, cls, name, valself, skip - ): + ) -> tuple[Any, Any]: attr = self._lookup_from_mro(node, cls, name, valself, skip) if not attr.bindings: return node, None @@ -450,7 +452,7 @@ def _lookup_from_mro_and_handle_descriptors( return self.ctx.join_cfg_nodes(nodes), result return node, attr - def _computable(self, name): + def _computable(self, name) -> bool: return not (name.startswith("__") and name.endswith("__")) def _get_attribute_computed( @@ -481,7 +483,9 @@ def _get_attribute_computed( ) return node, None - def _lookup_variable_annotation(self, node, base, name, valself): + def _lookup_variable_annotation( + self, node, base, name, valself + ) -> tuple[Any, Any]: if not isinstance(base, abstract.Class): return None, None annots = abstract_utils.get_annotations_dict(base.members) @@ -603,7 +607,7 @@ def _get_attribute_flat(self, node, cls, name, valself): else: return node, None - def _get_member(self, node, obj, name, valself): + def _get_member(self, node: _T0, obj, name, valself) -> tuple[_T0, Any]: """Get a member of an object.""" if isinstance(obj, mixin.LazyMembers): if not valself: diff --git a/pytype/block_environment.py b/pytype/block_environment.py index 27854b608..cfda8d9fb 100644 --- 
a/pytype/block_environment.py +++ b/pytype/block_environment.py @@ -13,16 +13,16 @@ class Environment: """A store of local variables per blockgraph node.""" - def __init__(self): + def __init__(self) -> None: self.block_locals: BlockLocals = {} # Blocks whose outgoing edges cannot be traversed. This can happen if, for # example, a block unconditionally raises an exception. self._dead_ends: set[blocks.Block] = set() - def mark_dead_end(self, block): + def mark_dead_end(self, block) -> None: self._dead_ends.add(block) - def add_block(self, frame, block): + def add_block(self, frame, block) -> None: """Add a new block and initialize its locals.""" local = {} @@ -59,8 +59,8 @@ def add_block(self, frame, block): var |= set(incoming_locals[k]) local[k] = list(var) - def store_local(self, block, name, var): + def store_local(self, block, name, var) -> None: self.block_locals[block][name] = [var] - def get_local(self, block, name): + def get_local(self, block, name) -> list[cfg.Variable] | None: return self.block_locals[block].get(name) diff --git a/pytype/blocks/block_serializer.py b/pytype/blocks/block_serializer.py index a62446d70..a0baa601c 100644 --- a/pytype/blocks/block_serializer.py +++ b/pytype/blocks/block_serializer.py @@ -43,7 +43,7 @@ class SerializedCode: class BlockGraphEncoder(json.JSONEncoder): """Implements the JSONEncoder behavior for ordered bytecode blocks.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) def _encode_code(self, code: SerializedCode) -> dict[str, Any]: @@ -70,7 +70,7 @@ def default(self, o): return super().default(o) -def encode_merged_graph(block_graph): +def encode_merged_graph(block_graph) -> str: out = [] for k, v in block_graph.graph.items(): for b in v.order: diff --git a/pytype/blocks/blocks.py b/pytype/blocks/blocks.py index 100b6c5e5..8f3d423c2 100644 --- a/pytype/blocks/blocks.py +++ b/pytype/blocks/blocks.py @@ -1,7 +1,7 @@ """Functions for computing 
the execution order of bytecode.""" from collections.abc import Iterator -from typing import Any, cast +from typing import Any, cast, TypeVar from pycnite import bytecode as pyc_bytecode from pycnite import marshal as pyc_marshal import pycnite.types @@ -9,7 +9,15 @@ from pytype.typegraph import cfg_utils from typing_extensions import Self -STORE_OPCODES = ( +_SelfBlock = TypeVar("_SelfBlock", bound="Block") + +STORE_OPCODES: tuple[ + type[opcodes.STORE_NAME], + type[opcodes.STORE_FAST], + type[opcodes.STORE_ATTR], + type[opcodes.STORE_DEREF], + type[opcodes.STORE_GLOBAL], +] = ( opcodes.STORE_NAME, opcodes.STORE_FAST, opcodes.STORE_ATTR, @@ -17,7 +25,13 @@ opcodes.STORE_GLOBAL, ) -_NOOP_OPCODES = (opcodes.NOP, opcodes.PRECALL, opcodes.RESUME) +_NOOP_OPCODES: tuple[ + type[opcodes.NOP], type[opcodes.PRECALL], type[opcodes.RESUME] +] = ( + opcodes.NOP, + opcodes.PRECALL, + opcodes.RESUME, +) class _Locals311: @@ -28,7 +42,7 @@ class _Locals311: CO_FAST_CELL = 0x40 CO_FAST_FREE = 0x80 - def __init__(self, code: pycnite.types.CodeType311): + def __init__(self, code: pycnite.types.CodeType311) -> None: table = list(zip(code.co_localsplusnames, code.co_localspluskinds)) filter_names = lambda k: tuple(name for name, kind in table if kind & k) self.co_varnames = filter_names(self.CO_FAST_LOCAL) @@ -67,16 +81,16 @@ def connect_outgoing(self, target: Self): self.outgoing.add(target) target.incoming.add(self) - def __str__(self): + def __str__(self) -> str: return "" % self.id - def __repr__(self): + def __repr__(self) -> str: return "" % (self.id, self.code) def __getitem__(self, index_or_slice): return self.code.__getitem__(index_or_slice) - def __iter__(self): + def __iter__(self) -> Iterator[opcodes.Opcode]: return self.code.__iter__() @@ -164,7 +178,7 @@ def __init__( for insn in bytecode: insn.code = self - def __repr__(self): + def __repr__(self) -> str: return f"OrderedCode({self.qualname}, version={self.python_version})" @property @@ -178,40 +192,40 @@ def 
co_consts(self): def code_iter(self) -> Iterator[opcodes.Opcode]: return (op for block in self.order for op in block) # pylint: disable=g-complex-comprehension - def get_first_opcode(self, skip_noop=False): + def get_first_opcode(self, skip_noop=False) -> opcodes.Opcode: for op in self.code_iter: if not skip_noop or not isinstance(op, _NOOP_OPCODES): return op assert False, "OrderedCode should have at least one opcode" - def has_opcode(self, op_type): + def has_opcode(self, op_type) -> bool: return any(isinstance(op, op_type) for op in self.code_iter) - def has_iterable_coroutine(self): + def has_iterable_coroutine(self) -> bool: return bool(self._co_flags & pyc_marshal.Flags.CO_ITERABLE_COROUTINE) - def set_iterable_coroutine(self): + def set_iterable_coroutine(self) -> None: self._co_flags |= pyc_marshal.Flags.CO_ITERABLE_COROUTINE - def has_coroutine(self): + def has_coroutine(self) -> bool: return bool(self._co_flags & pyc_marshal.Flags.CO_COROUTINE) - def has_generator(self): + def has_generator(self) -> bool: return bool(self._co_flags & pyc_marshal.Flags.CO_GENERATOR) - def has_async_generator(self): + def has_async_generator(self) -> bool: return bool(self._co_flags & pyc_marshal.Flags.CO_ASYNC_GENERATOR) - def has_varargs(self): + def has_varargs(self) -> bool: return bool(self._co_flags & pyc_marshal.Flags.CO_VARARGS) - def has_varkeywords(self): + def has_varkeywords(self) -> bool: return bool(self._co_flags & pyc_marshal.Flags.CO_VARKEYWORDS) - def has_newlocals(self): + def has_newlocals(self) -> bool: return bool(self._co_flags & pyc_marshal.Flags.CO_NEWLOCALS) - def get_arg_count(self): + def get_arg_count(self) -> int: """Total number of arg names including '*args' and '**kwargs'.""" count = self.argcount + self.kwonlyargcount if self.has_varargs(): @@ -228,13 +242,13 @@ def get_cell_index(self, name): class BlockGraph: """CFG made up of ordered code blocks.""" - def __init__(self): + def __init__(self) -> None: self.graph: dict[opcodes.Opcode, 
OrderedCode] = {} - def add(self, ordered_code: OrderedCode): + def add(self, ordered_code: OrderedCode) -> None: self.graph[ordered_code.get_first_opcode()] = ordered_code - def pretty_print(self): + def pretty_print(self) -> str: return str(self.graph) diff --git a/pytype/blocks/process_blocks.py b/pytype/blocks/process_blocks.py index 7aa9c4591..510d1a3cc 100644 --- a/pytype/blocks/process_blocks.py +++ b/pytype/blocks/process_blocks.py @@ -1,15 +1,18 @@ """Analyze code blocks and process opcodes.""" +from typing import TypeVar from pytype.blocks import blocks from pytype.pyc import opcodes from pytype.pyc import pyc +_T0 = TypeVar("_T0") + # Opcodes whose argument can be a block of code. -CODE_LOADING_OPCODES = (opcodes.LOAD_CONST,) +CODE_LOADING_OPCODES: tuple[type[opcodes.LOAD_CONST]] = (opcodes.LOAD_CONST,) -def _is_function_def(fn_code): +def _is_function_def(fn_code) -> bool: """Helper function for CollectFunctionTypeCommentTargetsVisitor.""" # Reject anything that is not a named function (e.g. ). first = fn_code.name[0] @@ -28,7 +31,7 @@ def _is_function_def(fn_code): class CollectAnnotationTargetsVisitor(pyc.CodeVisitor): """Collect opcodes that might have annotations attached.""" - def __init__(self): + def __init__(self) -> None: super().__init__() # A mutable map of line: opcode for STORE_* opcodes. This is modified as the # visitor runs, and contains the last opcode for each line. @@ -38,7 +41,7 @@ def __init__(self): # contain function type comments. 
self.make_function_ops = {} - def visit_code(self, code): + def visit_code(self, code: _T0) -> _T0: """Find STORE_* and MAKE_FUNCTION opcodes for attaching annotations.""" # Offset between function code and MAKE_FUNCTION # [LOAD_CONST , LOAD_CONST , MAKE_FUNCTION] @@ -93,11 +96,11 @@ def visit_code(self, code): class FunctionDefVisitor(pyc.CodeVisitor): """Add metadata to function definition opcodes.""" - def __init__(self, param_annotations): + def __init__(self, param_annotations) -> None: super().__init__() self.annots = param_annotations - def visit_code(self, code): + def visit_code(self, code: _T0) -> _T0: for op in code.code_iter: if isinstance(op, opcodes.MAKE_FUNCTION): if op.line in self.annots: @@ -145,7 +148,7 @@ def merge_annotations(code, annotations, param_annotations): return code -def adjust_returns(code, block_returns): +def adjust_returns(code, block_returns) -> None: """Adjust line numbers for return statements in with blocks.""" rets = {k: iter(v) for k, v in block_returns} diff --git a/pytype/compare.py b/pytype/compare.py index e32dcc711..25a12a6ab 100644 --- a/pytype/compare.py +++ b/pytype/compare.py @@ -5,13 +5,13 @@ from pytype.pytd import slots # Equality classes. -NUMERIC = frozenset( +NUMERIC: frozenset[str] = frozenset( {"builtins.bool", "builtins.int", "builtins.float", "builtins.complex"} ) -STRING = frozenset({"builtins.str", "builtins.unicode"}) +STRING: frozenset[str] = frozenset({"builtins.str", "builtins.unicode"}) # Fully qualified names of types that are parameterized containers. 
-_CONTAINER_NAMES = frozenset( +_CONTAINER_NAMES: frozenset[str] = frozenset( {"builtins.list", "builtins.set", "builtins.frozenset"} ) @@ -20,7 +20,7 @@ class CmpTypeError(Exception): """Comparing incompatible primitive constants.""" -def _incompatible(left_name, right_name): +def _incompatible(left_name, right_name) -> bool: """Incompatible primitive types can never be equal.""" if left_name == right_name: return False @@ -30,13 +30,13 @@ def _incompatible(left_name, right_name): return True -def _is_primitive_constant(ctx, value): +def _is_primitive_constant(ctx, value) -> bool: if isinstance(value, abstract.PythonConstant): return value.pyval.__class__ in ctx.convert.primitive_classes return False -def _is_primitive(ctx, value): +def _is_primitive(ctx, value) -> bool: if _is_primitive_constant(ctx, value): return True elif isinstance(value, abstract.Instance): @@ -44,7 +44,7 @@ def _is_primitive(ctx, value): return False -def _is_equality_cmp(op): +def _is_equality_cmp(op) -> bool: return op in (slots.EQ, slots.NE) @@ -183,7 +183,7 @@ def _compare_dict(op, left, right): return None -def _compare_class(op, left, right): +def _compare_class(op, left, right) -> None: del right # unused # Classes without a custom metaclass are not orderable. 
if left.cls.full_name != "builtins.type": diff --git a/pytype/config.py b/pytype/config.py index b7b87e85a..c4a030f6b 100644 --- a/pytype/config.py +++ b/pytype/config.py @@ -9,7 +9,7 @@ import logging import os import sys -from typing import Literal +from typing import Callable, Optional, Literal from typing import overload from pytype import datatypes @@ -22,7 +22,7 @@ from pytype.typegraph import cfg_utils -LOG_LEVELS = [ +LOG_LEVELS: list[int] = [ logging.CRITICAL, logging.ERROR, logging.WARNING, @@ -30,9 +30,11 @@ logging.DEBUG, ] -uses = utils.AnnotatingDecorator() # model relationship between options +uses: utils.AnnotatingDecorator = ( + utils.AnnotatingDecorator() +) # model relationship between options -_LIBRARY_ONLY_OPTIONS = { +_LIBRARY_ONLY_OPTIONS: dict[str, Optional[Callable]] = { # a custom file opening function that will be used in place of builtins.open "open_function": open, # Imports map as a list of tuples. @@ -57,7 +59,11 @@ def __init__( ): ... - def __init__(self, argv_or_options, command_line=False): + def __init__( + self, + argv_or_options: argparse.Namespace, + command_line: Literal[False] = False, + ): """Parse and encapsulate the configuration options. Also sets up some basic logger configuration. 
@@ -117,19 +123,19 @@ def create(cls, input_filename=None, **kwargs): setattr(options, k, v) return cls(options) - def tweak(self, **kwargs): + def tweak(self, **kwargs) -> None: for k, v in kwargs.items(): assert hasattr(self, k) # Don't allow adding arbitrary junk setattr(self, k, v) - def set_feature_flags(self, flags): + def set_feature_flags(self, flags) -> None: updates = {f.dest: True for f in FEATURE_FLAGS if f.flag in flags} self.tweak(**updates) def as_dict(self): return {k: v for k, v in self.__dict__.items() if not k.startswith("_")} - def __repr__(self): + def __repr__(self) -> str: return "\n".join([f"{k}: {v!r}" for k, v in sorted(self.as_dict().items())]) @@ -140,7 +146,7 @@ def make_parser(): return o -def base_parser(): +def base_parser() -> datatypes.ParserWrapper: """Use argparse to make a parser for configuration options.""" parser = argparse.ArgumentParser( usage="%(prog)s [options] input", @@ -149,7 +155,7 @@ def base_parser(): return datatypes.ParserWrapper(parser) -def add_all_pytype_options(o): +def add_all_pytype_options(o) -> None: """Add all pytype options to the given parser.""" # Input files o.add_argument("input", nargs="*", help="File to process") @@ -169,11 +175,11 @@ def add_all_pytype_options(o): class _Arg: """Hold args for argparse.ArgumentParser.add_argument.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: self.args = args self.kwargs = kwargs - def add_to(self, parser): + def add_to(self, parser) -> None: parser.add_argument(*self.args, **self.kwargs) def get(self, k): @@ -192,19 +198,19 @@ def dest(self): return self.kwargs["dest"] -def _flag(opt, default, help_text): +def _flag(opt, default, help_text) -> _Arg: dest = opt.lstrip("-").replace("-", "_") return _Arg( opt, dest=dest, default=default, help=help_text, action="store_true" ) -def add_options(o, arglist): +def add_options(o, arglist) -> None: for arg in arglist: arg.add_to(o) -MODES = [ +MODES: list[_Arg] = [ _Arg( "-C", 
"--check", @@ -225,7 +231,7 @@ def add_options(o, arglist): ] -BASIC_OPTIONS = [ +BASIC_OPTIONS: list[_Arg] = [ _Arg( "-d", "--disable", @@ -261,7 +267,7 @@ def add_options(o, arglist): ] -_OPT_IN_FEATURES = [ +_OPT_IN_FEATURES: list[_Arg] = [ # Feature flags that are not experimental, but are too strict to default # to True and are therefore left as opt-in features for users to enable. _flag("--no-return-any", False, "Do not allow Any as a return type."), @@ -274,7 +280,7 @@ def add_options(o, arglist): ] -FEATURE_FLAGS = [ +FEATURE_FLAGS: list[_Arg] = [ _flag( "--bind-decorated-methods", False, @@ -298,7 +304,7 @@ def add_options(o, arglist): ] + _OPT_IN_FEATURES -EXPERIMENTAL_FLAGS = [ +EXPERIMENTAL_FLAGS: list[_Arg] = [ _flag( "--precise-return", False, @@ -340,7 +346,7 @@ def add_options(o, arglist): ] -SUBTOOLS = [ +SUBTOOLS: list[_Arg] = [ _Arg( "--generate-builtins", action="store", @@ -358,7 +364,7 @@ def add_options(o, arglist): ] -PICKLE_OPTIONS = [ +PICKLE_OPTIONS: list[_Arg] = [ _Arg( "--pickle-output", action="store_true", @@ -404,7 +410,7 @@ def add_options(o, arglist): ] -INFRASTRUCTURE_OPTIONS = [ +INFRASTRUCTURE_OPTIONS: list[_Arg] = [ _Arg( "--imports_info", type=str, @@ -538,7 +544,7 @@ def add_options(o, arglist): ] -DEBUG_OPTIONS = [ +DEBUG_OPTIONS: list[_Arg] = [ _Arg( "--check_preconditions", action="store_true", @@ -686,7 +692,7 @@ def add_options(o, arglist): ] -ALL_OPTIONS = ( +ALL_OPTIONS: list[_Arg] = ( MODES + BASIC_OPTIONS + SUBTOOLS @@ -703,17 +709,17 @@ def args_map(): return {x.get("dest"): x for x in ALL_OPTIONS} -def add_modes(o): +def add_modes(o) -> None: """Add operation modes to the given parser.""" add_options(o, MODES) -def add_basic_options(o): +def add_basic_options(o) -> None: """Add basic options to the given parser.""" add_options(o, BASIC_OPTIONS) -def add_feature_flags(o): +def add_feature_flags(o) -> None: """Add flags for experimental and temporarily gated features.""" def flag(arg, temporary, experimental): @@ 
-741,26 +747,26 @@ def flag(arg, temporary, experimental): flag(arg, False, True) -def add_subtools(o): +def add_subtools(o) -> None: """Add subtools to the given parser.""" # TODO(rechen): These should be standalone tools. o = o.add_argument_group("subtools") add_options(o, SUBTOOLS) -def add_pickle_options(o): +def add_pickle_options(o) -> None: """Add options for using pickled pyi files to the given parser.""" o = o.add_argument_group("pickle arguments") add_options(o, PICKLE_OPTIONS) -def add_infrastructure_options(o): +def add_infrastructure_options(o) -> None: """Add infrastructure options to the given parser.""" o = o.add_argument_group("infrastructure arguments") add_options(o, INFRASTRUCTURE_OPTIONS) -def add_debug_options(o): +def add_debug_options(o) -> None: """Add debug options to the given parser.""" o = o.add_argument_group("debug arguments") add_options(o, DEBUG_OPTIONS) @@ -773,14 +779,16 @@ class PostprocessingError(Exception): class Postprocessor: """Postprocesses configuration options.""" - def __init__(self, names, opt_map, input_options, output_options=None): + def __init__( + self, names, opt_map, input_options, output_options=None + ) -> None: self.names = names self.opt_map = opt_map self.input_options = input_options # If output not specified, process in-place. self.output_options = output_options or input_options - def process(self): + def process(self) -> None: """Postprocesses all options in self.input_options. 
This will iterate through all options in self.input_options and make them @@ -847,7 +855,7 @@ def _display_opt(self, opt): else: return self.opt_map[opt] - def _check_exclusive(self, name, value, existing): + def _check_exclusive(self, name, value, existing) -> None: """Check for argument conflicts.""" if existing in _LIBRARY_ONLY_OPTIONS: # Library-only options are often used as an alternate way of setting a @@ -865,7 +873,7 @@ def _check_exclusive(self, name, value, existing): opt = self._display_opt(existing) self.error(f"Not allowed with {opt}", name) - def _check_required(self, name, value, existing): + def _check_required(self, name, value, existing) -> None: """Check for required args.""" if value and not getattr(self.output_options, existing, None): opt = self._display_opt(existing) @@ -928,7 +936,7 @@ def _store_verbosity(self, verbosity): self.error(f"invalid --verbosity: {verbosity}") self.output_options.verbosity = verbosity - def _store_pythonpath(self, pythonpath): + def _store_pythonpath(self, pythonpath) -> None: # Note that the below gives [""] for "", and ["x", ""] for "x:" # ("" is a valid entry to denote the current directory) self.output_options.pythonpath = pythonpath.split(os.pathsep) @@ -957,7 +965,7 @@ def _store_python_version(self, python_version): except compiler.PythonNotFoundError: self.error("Need a valid python%d.%d executable in $PATH" % version) - def _store_disable(self, disable): + def _store_disable(self, disable) -> None: if disable: self.output_options.disable = disable.split(",") else: @@ -1006,10 +1014,10 @@ def _store_output_errors_csv(self, output_errors_csv): self.error("Not allowed with --no-report-errors", "output-errors-csv") self.output_options.output_errors_csv = output_errors_csv - def _store_exec_log(self, exec_log): + def _store_exec_log(self, exec_log) -> None: self.output_options.exec_log = exec_log - def _store_color(self, color): + def _store_color(self, color) -> None: if color not in ("always", "auto", 
"never"): raise ValueError( f"--color flag allows only 'always', 'auto' or 'never', not {color!r}" @@ -1035,7 +1043,7 @@ def _store_analyze_annotated(self, analyze_annotated): analyze_annotated = self.output_options.check self.output_options.analyze_annotated = analyze_annotated - def _parse_arguments(self, arguments): + def _parse_arguments(self, arguments) -> tuple | None: """Parse the input/output arguments.""" if len(arguments) > 1: self.error("Can only process one file at a time.") @@ -1053,14 +1061,14 @@ def _parse_arguments(self, arguments): % (item, os.pathsep) ) - def _store_pickle_metadata(self, pickle_metadata): + def _store_pickle_metadata(self, pickle_metadata) -> None: if pickle_metadata: self.output_options.pickle_metadata = pickle_metadata.split(",") else: self.output_options.pickle_metadata = [] -def _set_verbosity(verbosity, timestamp_logs, debug_logs): +def _set_verbosity(verbosity, timestamp_logs, debug_logs) -> None: """Set the logging verbosity.""" if verbosity >= 0: basic_logging_level = LOG_LEVELS[verbosity] diff --git a/pytype/constant_folding.py b/pytype/constant_folding.py index 8c9463c3a..3daa55f7f 100644 --- a/pytype/constant_folding.py +++ b/pytype/constant_folding.py @@ -23,7 +23,7 @@ input/output. 
""" -from typing import Any +from typing import TypeVar, Union, Any import attrs from pycnite import marshal as pyc_marshal @@ -31,6 +31,8 @@ from pytype.pyc import opcodes from pytype.pyc import pyc +_T0 = TypeVar('_T0') + # Copied from typegraph/cfg.py # If we have more than 64 elements in a map/list, the type variable accumulates @@ -43,7 +45,7 @@ class ConstantError(Exception): """Errors raised during constant folding.""" - def __init__(self, message, op): + def __init__(self, message, op) -> None: super().__init__(message) self.lineno = op.line self.message = message @@ -82,7 +84,7 @@ class _Constant: def tag(self): return self.typ[0] - def __repr__(self): + def __repr__(self) -> str: return repr(self.value) @@ -109,17 +111,17 @@ class _Map: class _CollectionBuilder: """Build up a collection of constants.""" - def __init__(self): + def __init__(self) -> None: self.types = set() self.values = [] self.elements = [] - def add(self, constant): + def add(self, constant) -> None: self.types.add(constant.typ) self.elements.append(constant) self.values.append(constant.value) - def build(self): + def build(self) -> _Collection: return _Collection( types=frozenset(self.types), values=tuple(reversed(self.values)), @@ -130,21 +132,21 @@ def build(self): class _MapBuilder: """Build up a map of constants.""" - def __init__(self): + def __init__(self) -> None: self.key_types = set() self.value_types = set() self.keys = [] self.values = [] self.elements = {} - def add(self, key, value): + def add(self, key, value) -> None: self.key_types.add(key.typ) self.value_types.add(value.typ) self.keys.append(key.value) self.values.append(value.value) self.elements[key.value] = value - def build(self): + def build(self) -> _Map: return _Map( key_types=frozenset(self.key_types), keys=tuple(reversed(self.keys)), @@ -157,27 +159,27 @@ def build(self): class _Stack: """A simple opcode stack.""" - def __init__(self): + def __init__(self) -> None: self.stack = [] self.consts = {} def 
__iter__(self): return self.stack.__iter__() - def push(self, val): + def push(self, val) -> None: self.stack.append(val) def pop(self): return self.stack.pop() - def _preserve_constant(self, c): + def _preserve_constant(self, c) -> None: if c and ( not isinstance(c.op, opcodes.LOAD_CONST) or isinstance(c.op, opcodes.BUILD_STRING) ): self.consts[id(c.op)] = c - def clear(self): + def clear(self) -> None: # Preserve any constants in the stack before clearing it. for c in self.stack: self._preserve_constant(c) @@ -236,7 +238,7 @@ def build_str(self, n, op): self.push(None) return ret - def build(self, python_type, op): + def build(self, python_type, op) -> None: """Build a folded type.""" collection = self.fold_args(op.arg, op) if collection: @@ -253,10 +255,10 @@ def build(self, python_type, op): class _FoldedOps: """Mapping from a folded opcode to the top level constant that replaces it.""" - def __init__(self): + def __init__(self) -> None: self.folds = {} - def add(self, op): + def add(self, op) -> None: self.folds[id(op)] = op.folded def resolve(self, op): @@ -453,7 +455,7 @@ def union(params): return (tag, union(params)) -def from_literal(tup): +def from_literal(tup: _T0) -> Union[tuple, _T0]: """Convert from simple literal form to the more uniform typestruct.""" def expand(vals): diff --git a/pytype/context.py b/pytype/context.py index 89d029ac1..30582b25b 100644 --- a/pytype/context.py +++ b/pytype/context.py @@ -21,7 +21,7 @@ from pytype.typegraph import cfg from pytype.typegraph import cfg_utils -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) class Context: @@ -107,7 +107,7 @@ def __init__( # cache doesn't persist between runs. 
self.function_cache = {} - def matcher(self, node): + def matcher(self, node) -> matcher.AbstractMatcher: return matcher.AbstractMatcher(node, self) @contextlib.contextmanager @@ -153,7 +153,7 @@ def make_class(self, node, props): def check_annotation_type_mismatch( self, node, name, typ, value, stack, allow_none, details=None - ): + ) -> None: """Checks for a mismatch between a variable's annotation and value. Args: diff --git a/pytype/convert.py b/pytype/convert.py index 08f198045..d390ae255 100644 --- a/pytype/convert.py +++ b/pytype/convert.py @@ -3,7 +3,7 @@ import contextlib import logging import types -from typing import Any +from typing import TypeVar, Any import pycnite from pytype import datatypes @@ -27,7 +27,9 @@ from pytype.pytd import pytd_utils from pytype.typegraph import cfg -log = logging.getLogger(__name__) +_T1 = TypeVar("_T1") + +log: logging.Logger = logging.getLogger(__name__) _MAX_IMPORT_DEPTH = 12 @@ -155,7 +157,7 @@ def constant_name(self, constant_type): else: return constant_type.__name__ - def _type_to_name(self, t): + def _type_to_name(self, t) -> tuple[str, Any]: """Convert a type to its name.""" assert t.__class__ is type if t is types.FunctionType: @@ -184,7 +186,7 @@ def lookup_value(self, module, name, subst=None): subst = subst or datatypes.AliasingDict() return self.constant_to_value(pytd_cls, subst) - def tuple_to_value(self, content): + def tuple_to_value(self, content) -> abstract.Tuple: """Create a VM tuple from the given sequence.""" content = tuple(content) # content might be a generator value = abstract.Tuple(content, self.ctx) @@ -199,7 +201,7 @@ def build_bool(self, node, value=None): else: raise ValueError(f"Invalid bool value: {value!r}") - def build_concrete_value(self, value, typ): + def build_concrete_value(self, value, typ) -> abstract.ConcreteValue: typ = self.primitive_classes[typ] return abstract.ConcreteValue(value, typ, self.ctx) @@ -274,16 +276,16 @@ def build_tuple(self, node, content): """Create a VM tuple 
from the given sequence.""" return self.tuple_to_value(content).to_variable(node) - def make_typed_dict_builder(self): + def make_typed_dict_builder(self) -> typed_dict.TypedDictBuilder: """Make a typed dict builder.""" return typed_dict.TypedDictBuilder(self.ctx) - def make_typed_dict(self, name, pytd_cls): + def make_typed_dict(self, name, pytd_cls) -> typed_dict.TypedDictClass: """Make a typed dict from a pytd class.""" builder = typed_dict.TypedDictBuilder(self.ctx) return builder.make_class_from_pyi(name, pytd_cls) - def make_namedtuple_builder(self): + def make_namedtuple_builder(self) -> named_tuple.NamedTupleClassBuilder: """Make a namedtuple builder.""" return named_tuple.NamedTupleClassBuilder(self.ctx) @@ -292,7 +294,7 @@ def make_namedtuple(self, name, pytd_cls): builder = named_tuple.NamedTupleClassBuilder(self.ctx) return builder.make_class_from_pyi(name, pytd_cls) - def apply_dataclass_transform(self, cls_var, node): + def apply_dataclass_transform(self, cls_var, node: _T1) -> tuple[_T1, Any]: cls = abstract_utils.get_atomic_value(cls_var) # We need to propagate the metadata key since anything in the entire tree of # subclasses is a dataclass, even without a decorator. 
@@ -376,7 +378,7 @@ def _copy_type_parameters( else: return new_container - def widen_type(self, container): + def widen_type(self, container) -> abstract.BaseValue: """Widen a tuple to an iterable, or a dict to a mapping.""" if container.full_name == "builtins.tuple": return self._copy_type_parameters(container, "typing", "Iterable") @@ -605,7 +607,7 @@ def get_node(): need_node[0] = True return node - recursive = isinstance(pyval, pytd.LateType) and pyval.recursive + recursive = isinstance(pyval, pytd.LateType) and pyval.recursive # pytype: disable=attribute-error if recursive: context = self.ctx.allow_recursive_convert() else: @@ -630,7 +632,7 @@ def get_node(): # d = {"a": 1j} if recursive: annot = abstract.LateAnnotation( - pyval.name, self.ctx.vm.frames, self.ctx + pyval.name, self.ctx.vm.frames, self.ctx # pytype: disable=attribute-error ) annot.set_type(value) value = annot @@ -641,7 +643,7 @@ def _load_late_type(self, late_type): """Resolve a late type, possibly by loading a module.""" return self.ctx.loader.load_late_type(late_type) - def _create_module(self, ast): + def _create_module(self, ast) -> abstract.Module: if not ast: raise abstract_utils.ModuleLoadError() data = ( diff --git a/pytype/convert_structural.py b/pytype/convert_structural.py index 22b3ae4f1..a72525feb 100644 --- a/pytype/convert_structural.py +++ b/pytype/convert_structural.py @@ -2,7 +2,7 @@ import itertools import logging -from typing import AbstractSet +from typing import Any, Union, AbstractSet from pytype.pytd import booleq from pytype.pytd import escape @@ -13,7 +13,7 @@ from pytype.pytd import type_match from pytype.pytd import visitors -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) # How deep to nest type parameters # TODO(b/159041279): Currently, the solver only generates variables for depth 1. 
@@ -31,12 +31,14 @@ class FlawedQuery(Exception): # pylint: disable=g-bad-exception-name class TypeSolver: """Class for solving ~unknowns in type inference results.""" - def __init__(self, ast, builtins, protocols): + def __init__(self, ast, builtins, protocols) -> None: self.ast = ast self.builtins = builtins self.protocols = protocols - def match_unknown_against_protocol(self, matcher, solver, unknown, complete): + def match_unknown_against_protocol( + self, matcher, solver, unknown, complete + ) -> None: """Given an ~unknown, match it against a class. Args: @@ -67,7 +69,9 @@ def match_unknown_against_protocol(self, matcher, solver, unknown, complete): solver.register_variable(param.name) solver.implies(booleq.Eq(unknown.name, complete.name), implication) - def match_partial_against_complete(self, matcher, solver, partial, complete): + def match_partial_against_complete( + self, matcher, solver, partial, complete + ) -> None: """Match a partial class (call record) against a complete class. Args: @@ -93,7 +97,7 @@ def match_partial_against_complete(self, matcher, solver, partial, complete): raise FlawedQuery(f"{partial.name} can never be {complete.name}") solver.always_true(formula) - def match_call_record(self, matcher, solver, call_record, complete): + def match_call_record(self, matcher, solver, call_record, complete) -> None: """Match the record of a method call against the formal signature.""" assert is_partial(call_record) assert is_complete(complete) @@ -118,7 +122,7 @@ def match_call_record(self, matcher, solver, call_record, complete): ) solver.always_true(formula) - def solve(self): + def solve(self) -> dict: """Solve the equations generated from the pytd. Returns: @@ -210,7 +214,7 @@ def solve(self): return merged_solution -def solve(ast, builtins_pytd, protocols_pytd): +def solve(ast, builtins_pytd, protocols_pytd) -> tuple[Any, Any]: """Solve the unknowns in a pytd AST using the standard Python builtins. 
Args: @@ -231,7 +235,7 @@ class names and (2) a pytd.TypeDeclUnit of the complete classes in ast. ) -def extract_local(ast): +def extract_local(ast) -> pytd.TypeDeclUnit: """Extract all classes that are not unknowns of call records of builtins.""" return pytd.TypeDeclUnit( name=ast.name, @@ -243,7 +247,9 @@ def extract_local(ast): ) -def convert_string_type(string_type, unknown, mapping, global_lookup, depth=0): +def convert_string_type( + string_type, unknown, mapping, global_lookup, depth=0 +) -> Union[pytd.ClassType, pytd.GenericType, pytd.NamedType]: """Convert a string representing a type back to a pytd type.""" try: # Check whether this is a type declared in a pytd. diff --git a/pytype/datatypes.py b/pytype/datatypes.py index 3a51af38c..23c08ac10 100644 --- a/pytype/datatypes.py +++ b/pytype/datatypes.py @@ -3,10 +3,13 @@ import argparse import contextlib import itertools -from typing import TypeVar +from typing import Any, TypeVar import immutabledict +_TParserWrapper = TypeVar("_TParserWrapper", bound="ParserWrapper") + + _K = TypeVar("_K") _V = TypeVar("_V") @@ -48,14 +51,14 @@ class UnionFind: latest_id: the maximal allocated id. 
""" - def __init__(self): + def __init__(self) -> None: self.name2id = {} self.parent = [] self.rank = [] self.id2name = [] self.latest_id = 0 - def merge_from(self, uf): + def merge_from(self, uf) -> None: """Merge a UnionFind into the current one.""" for i, name in enumerate(uf.id2name): self.merge(name, uf.id2name[uf.parent[i]]) @@ -72,7 +75,7 @@ def merge(self, name1, name2): self._merge(key1, key2) return self.find_by_name(name1) - def _get_or_add_id(self, name): + def _get_or_add_id(self, name) -> int: if name not in self.name2id: self.name2id[name] = self.latest_id self.parent.append(self.latest_id) @@ -91,7 +94,7 @@ def _find(self, key): self.parent[key] = res return res - def _merge(self, k1, k2): + def _merge(self, k1, k2) -> None: """Merge two components.""" assert self.latest_id > k1 and self.latest_id > k2 s1 = self._find(k1) @@ -105,7 +108,7 @@ def _merge(self, k1, k2): self.parent[s1] = s2 self.rank[s2] += 1 - def __repr__(self): + def __repr__(self) -> str: comps = [] used = set() for x in self.id2name: @@ -122,7 +125,7 @@ def __repr__(self): class AccessTrackingDict(dict[_K, _V]): """A dict that tracks access of its original items.""" - def __init__(self, d=()): + def __init__(self, d=()) -> None: super().__init__(d) self.accessed_subset = {} @@ -132,18 +135,18 @@ def __getitem__(self, k): self.accessed_subset[k] = v return v - def __setitem__(self, k, v): + def __setitem__(self, k, v) -> None: if k in self: _ = self[k] # If the key is new, we don't track it. 
return super().__setitem__(k, v) - def __delitem__(self, k): + def __delitem__(self, k) -> None: if k in self: _ = self[k] return super().__delitem__(k) - def update(self, *args, **kwargs): + def update(self, *args, **kwargs) -> None: super().update(*args, **kwargs) for d in args: if isinstance(d, AccessTrackingDict): @@ -180,7 +183,7 @@ def data(self): class AliasingDictConflictError(Exception): - def __init__(self, existing_name): + def __init__(self, existing_name) -> None: super().__init__() self.existing_name = existing_name @@ -214,25 +217,25 @@ def __init__(self, *args, aliases: UnionFind | None = None, **kwargs): def aliases(self): return self._aliases - def copy(self, *args, aliases=None, **kwargs): + def copy(self, *args, aliases=None, **kwargs) -> "AliasingDict": return self.__class__(self, *args, aliases=aliases, **kwargs) def same_name(self, name1, name2): return self.aliases.find_by_name(name1) == self.aliases.find_by_name(name2) - def __contains__(self, name): + def __contains__(self, name) -> bool: return super().__contains__(self.aliases.find_by_name(name)) - def __setitem__(self, name, var): + def __setitem__(self, name, var) -> None: super().__setitem__(self.aliases.find_by_name(name), var) def __getitem__(self, name): return super().__getitem__(self.aliases.find_by_name(name)) - def __repr__(self): + def __repr__(self) -> str: return f"{super().__repr__()!r}, _alias={repr(self.aliases)!r}" - def __hash__(self): + def __hash__(self) -> int: return hash(frozenset(self.items())) def get(self, name, default=None): @@ -282,7 +285,7 @@ def viewkeys(self): def viewvalues(self): raise NotImplementedError() - def merge_from(self, lam_dict, op): + def merge_from(self, lam_dict, op) -> None: """Merge the other `AliasingDict` into current class. 
Args: @@ -305,7 +308,7 @@ def merge_from(self, lam_dict, op): ): self.add_alias(cur_name, parent_name, op) - def _merge(self, name1, name2, op): + def _merge(self, name1, name2, op) -> None: name1 = self.aliases.find_by_name(name1) name2 = self.aliases.find_by_name(name2) assert name1 != name2 @@ -314,14 +317,14 @@ def _merge(self, name1, name2, op): root = self.aliases.merge(name1, name2) self._copy_item(name1, root) - def _copy_item(self, src, tgt): + def _copy_item(self, src, tgt) -> None: """Assign the dict `src` value to `tgt`.""" if src == tgt: return self[tgt] = dict.__getitem__(self, src) dict.__delitem__(self, src) - def add_alias(self, alias, name, op=None): + def add_alias(self, alias, name, op=None) -> None: """Alias 'alias' to 'name'. After aliasing, we will think `alias` and `name`, they represent the same @@ -355,7 +358,7 @@ class HashableDict(AliasingDict[_K, _V]): have been overwritten to throw an exception. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._hash = hash(frozenset(self.items())) @@ -380,7 +383,7 @@ def __setitem__(self, name, var): def __delitem__(self, y): raise TypeError() - def __hash__(self): + def __hash__(self) -> int: return self._hash @@ -391,13 +394,13 @@ class AliasingMonitorDict(AliasingDict[_K, _V], MonitorDict[_K, _V]): class Box: """A mutable shared value.""" - def __init__(self, value=None): + def __init__(self, value=None) -> None: self._value = value def __get__(self, unused_obj, unused_objname): return self._value - def __set__(self, unused_obj, value): + def __set__(self, unused_obj, value) -> None: self._value = value @@ -405,9 +408,9 @@ class ParserWrapper: """Wrapper that adds arguments to a parser while recording them.""" # This needs to be a classvar so that it is shared by subgroups - _only = Box(None) + _only: Any = Box(None) - def __init__(self, parser, actions=None): + def __init__(self, parser, actions=None) -> None: 
self.parser = parser self.actions = {} if actions is None else actions @@ -421,7 +424,7 @@ def add_only(self, args): finally: self._only = only - def add_argument(self, *args, **kwargs): + def add_argument(self, *args, **kwargs) -> None: if self._only and not any(arg in self._only for arg in args): return try: @@ -432,7 +435,9 @@ def add_argument(self, *args, **kwargs): else: self.actions[action.dest] = action - def add_argument_group(self, *args, **kwargs): + def add_argument_group( + self: _TParserWrapper, *args, **kwargs + ) -> _TParserWrapper: group = self.parser.add_argument_group(*args, **kwargs) wrapped_group = self.__class__(group, actions=self.actions) return wrapped_group diff --git a/pytype/debug.py b/pytype/debug.py index 1912f283c..bb389db0b 100644 --- a/pytype/debug.py +++ b/pytype/debug.py @@ -7,13 +7,18 @@ import logging import re import traceback +from typing import Any, Callable, TypeVar from pytype import utils from pytype.typegraph import cfg_utils import tabulate +_T1 = TypeVar("_T1") -def _ascii_tree(out, node, p1, p2, seen, get_children, get_description=None): + +def _ascii_tree( + out, node, p1, p2, seen, get_children, get_description=None +) -> None: """Draw a graph, starting at a given position. Args: @@ -46,7 +51,7 @@ def _ascii_tree(out, node, p1, p2, seen, get_children, get_description=None): ) -def ascii_tree(node, get_children, get_description=None): +def ascii_tree(node, get_children, get_description=None) -> str: """Draw a graph, starting at a given position. 
Args: @@ -62,7 +67,7 @@ def ascii_tree(node, get_children, get_description=None): return out.getvalue() -def prettyprint_binding(binding, indent_level=0): +def prettyprint_binding(binding, indent_level=0) -> str: """Pretty print a binding with variable id and data.""" indent = " " * indent_level if not binding: @@ -70,7 +75,7 @@ def prettyprint_binding(binding, indent_level=0): return "%s" % (indent, binding.variable.id, binding.data) -def prettyprint_binding_set(binding_set, indent_level=0, label=""): +def prettyprint_binding_set(binding_set, indent_level=0, label="") -> str: """Pretty print a set of bindings, with optional label.""" indent = " " * indent_level start = f"{indent}{label}: {{" @@ -83,7 +88,7 @@ def prettyprint_binding_set(binding_set, indent_level=0, label=""): ) -def prettyprint_binding_nested(binding, indent_level=0): +def prettyprint_binding_nested(binding, indent_level=0) -> str: """Pretty print a binding and its recursive contents.""" indent = " " * indent_level if indent_level > 32: @@ -107,7 +112,7 @@ def prettyprint_binding_nested(binding, indent_level=0): return s -def prettyprint_cfg_node(node, decorate_after_node=0, full=False): +def prettyprint_cfg_node(node, decorate_after_node=0, full=False) -> str: """A reasonably compact representation of all the bindings at a node. Args: @@ -155,7 +160,7 @@ def prettyprint_cfg_tree( return ascii_tree(root, get_children=children, get_description=desc) -def _pretty_variable(var): +def _pretty_variable(var) -> str: """Return a pretty printed string for a Variable.""" lines = [] single_value = len(var.bindings) == 1 @@ -189,7 +194,7 @@ def _pretty_variable(var): return "\n".join(lines) -def program_to_text(program): +def program_to_text(program) -> str: """Generate a text (CFG nodes + assignments) version of a program. For debugging only. 
@@ -226,7 +231,7 @@ def label(node): return s.getvalue() -def root_cause(binding, node, seen=()): +def root_cause(binding, node, seen=()) -> tuple[Any, Any]: """Tries to determine why a binding isn't possible at a node. This tries to find the innermost source that's still impossible. It only works @@ -260,7 +265,7 @@ def root_cause(binding, node, seen=()): return None, None -def stack_trace(indent_level=0, limit=100): +def stack_trace(indent_level=0, limit=100) -> str: indent = " " * indent_level stack = [ frame @@ -272,7 +277,7 @@ def stack_trace(indent_level=0, limit=100): return "\n ".join(tb) -def _setup_tabulate(): +def _setup_tabulate() -> None: """Customise tabulate.""" tabulate.PRESERVE_WHITESPACE = True tabulate.MIN_PADDING = 0 @@ -291,7 +296,7 @@ def _setup_tabulate(): # pytype: enable=module-attr -def show_ordered_code(code, extra_col=None): +def show_ordered_code(code, extra_col=None) -> None: """Print out the block structure of an OrderedCode object as a table. Args: @@ -350,12 +355,12 @@ def show_ordered_code(code, extra_col=None): # Tracing logger -def tracer(name=None): +def tracer(name=None) -> logging.Logger: name = f"trace.{name}" if name else "trace" return logging.getLogger(name) -def set_trace_level(level): +def set_trace_level(level) -> None: logging.getLogger("trace").setLevel(level) @@ -370,7 +375,7 @@ def tracing(level=logging.DEBUG): log.setLevel(current_level) -def trace(name, *trace_args): +def trace(name, *trace_args) -> Callable[[Any], Any]: """Record args and return value for a function call. 
The trace is of the form @@ -416,7 +421,7 @@ def wrapper(*args, **kwargs): return decorator -def show(x): +def show(x) -> str: """Pretty print values for debugging.""" typename = x.__class__.__name__ if typename == "Variable": diff --git a/pytype/directors/annotations.py b/pytype/directors/annotations.py index 034f1493d..6fb8ae407 100644 --- a/pytype/directors/annotations.py +++ b/pytype/directors/annotations.py @@ -14,14 +14,14 @@ class VariableAnnotation: class VariableAnnotations: """Store variable annotations and typecomments for a program.""" - def __init__(self): + def __init__(self) -> None: self.variable_annotations: dict[int, VariableAnnotation] = {} self.type_comments: dict[int, str] = {} - def add_annotation(self, line: int, name: str, annotation: str): + def add_annotation(self, line: int, name: str, annotation: str) -> None: self.variable_annotations[line] = VariableAnnotation(name, annotation) - def add_type_comment(self, line: int, annotation: str): + def add_type_comment(self, line: int, annotation: str) -> None: self.type_comments[line] = annotation @property diff --git a/pytype/directors/directors.py b/pytype/directors/directors.py index 67e7f6281..85919bdfd 100644 --- a/pytype/directors/directors.py +++ b/pytype/directors/directors.py @@ -4,26 +4,27 @@ import collections import logging import sys - +from typing import Any from pytype import config from pytype.directors import annotations - from pytype.directors import parser + + # pylint: enable=g-import-not-at-top -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) -SkipFileError = parser.SkipFileError +SkipFileError: type[parser.SkipFileError] = parser.SkipFileError parse_src = parser.parse_src _ALL_ERRORS = "*" # Wildcard for disabling all errors. 
-_ALLOWED_FEATURES = frozenset(x.flag for x in config.FEATURE_FLAGS) +_ALLOWED_FEATURES: frozenset = frozenset(x.flag for x in config.FEATURE_FLAGS) -_PRAGMAS = frozenset({"cache-return"}) +_PRAGMAS: frozenset[str] = frozenset({"cache-return"}) -_FUNCTION_CALL_ERRORS = frozenset(( +_FUNCTION_CALL_ERRORS: frozenset[str] = frozenset(( # A function call may implicitly access a magic method attribute. "attribute-error", "duplicate-keyword", @@ -37,7 +38,7 @@ "unsupported-operands", )) -_ALL_ADJUSTABLE_ERRORS = _FUNCTION_CALL_ERRORS.union(( +_ALL_ADJUSTABLE_ERRORS: frozenset[str] = _FUNCTION_CALL_ERRORS.union(( "annotation-type-mismatch", "bad-return-type", "bad-yield-annotation", @@ -60,7 +61,7 @@ class _LineSet: their own line apply until countered by the opposing directive. """ - def __init__(self): + def __init__(self) -> None: # Map of line->bool for specific lines, takes precedence over _transitions. self._lines = {} # A sorted list of the lines at which the range state changes @@ -75,11 +76,11 @@ def __init__(self): def lines(self): return self._lines - def set_line(self, line, membership): + def set_line(self, line, membership) -> None: """Set whether a given line is a member of the set.""" self._lines[line] = membership - def start_range(self, line, membership): + def start_range(self, line, membership) -> None: """Start a range of lines that are either included/excluded from the set. Args: @@ -108,7 +109,7 @@ def start_range(self, line, membership): # Normal case - add a transition at this line. self._transitions.append(line) - def __contains__(self, line): + def __contains__(self, line) -> bool: """Return if a line is a member of the set.""" # First check for an entry in _lines. 
specific = self._lines.get(line) @@ -129,18 +130,18 @@ def get_disable_after(self, line): class _BlockRanges: """A collection of possibly nested start..end ranges from AST nodes.""" - def __init__(self, start_to_end_mapping): + def __init__(self, start_to_end_mapping) -> None: self._starts = sorted(start_to_end_mapping) self._start_to_end = start_to_end_mapping self._end_to_start = {v: k for k, v in start_to_end_mapping.items()} - def has_start(self, line): + def has_start(self, line) -> bool: return line in self._start_to_end - def has_end(self, line): + def has_end(self, line) -> bool: return line in self._end_to_start - def find_outermost(self, line): + def find_outermost(self, line) -> tuple[Any, Any]: """Find the outermost interval containing line.""" i = bisect.bisect_left(self._starts, line) num_intervals = len(self._starts) @@ -161,7 +162,7 @@ def find_outermost(self, line): return start, end return None, None - def adjust_end(self, old_end, new_end): + def adjust_end(self, old_end, new_end) -> None: start = self._end_to_start[old_end] self._start_to_end[start] = new_end del self._end_to_start[old_end] @@ -171,7 +172,7 @@ def adjust_end(self, old_end, new_end): class Director: """Holds all of the directive information for a source file.""" - def __init__(self, src_tree, errorlog, filename, disable): + def __init__(self, src_tree, errorlog, filename, disable) -> None: """Create a Director for a source file. Args: @@ -237,10 +238,10 @@ def decorators(self): def decorated_functions(self): return self._decorated_functions - def has_pragma(self, pragma, line): + def has_pragma(self, pragma, line) -> bool: return pragma in self._pragmas and line in self._pragmas[pragma] - def _parse_src_tree(self, src_tree): + def _parse_src_tree(self, src_tree) -> None: """Parse a source file, extracting directives from comments.""" visitor = parser.visit_src_tree(src_tree) # TODO(rechen): This check can be removed once parser_libcst is gone. 
@@ -353,7 +354,7 @@ def _process_pytype( else: raise _DirectiveError(f"Unknown pytype directive: '{command}'") - def _process_features(self, features: set[str]): + def _process_features(self, features: set[str]) -> None: invalid = features - _ALLOWED_FEATURES if invalid: raise _DirectiveError(f"Unknown pytype features: {','.join(invalid)}") @@ -427,7 +428,7 @@ def _adjust_line_number_for_pytype_directive( return line return line_range.start_line - def filter_error(self, error): + def filter_error(self, error) -> bool: """Return whether the error should be logged. This method is suitable for use as an error filter. diff --git a/pytype/directors/parser.py b/pytype/directors/parser.py index 6ed97eedb..156681fe1 100644 --- a/pytype/directors/parser.py +++ b/pytype/directors/parser.py @@ -2,21 +2,24 @@ import ast import collections -from collections.abc import Mapping, Sequence +from collections.abc import Generator, Mapping, Sequence import dataclasses import io import logging import re import tokenize +from typing import Any, TypeVar from pytype.ast import visitor -log = logging.getLogger(__name__) +_TLineRange = TypeVar("_TLineRange", bound="LineRange") + +log: logging.Logger = logging.getLogger(__name__) # Also supports mypy-style ignore[code, ...] syntax, treated as regular ignores. 
-IGNORE_RE = re.compile(r"^ignore(\[.+\])?$") +IGNORE_RE: re.Pattern = re.compile(r"^ignore(\[.+\])?$") -_DIRECTIVE_RE = re.compile(r"#\s*(pytype|type)\s*:\s?([^#]*)") +_DIRECTIVE_RE: re.Pattern = re.compile(r"#\s*(pytype|type)\s*:\s?([^#]*)") class SkipFileError(Exception): @@ -29,7 +32,7 @@ class LineRange: end_line: int @classmethod - def from_node(cls, node): + def from_node(cls: type[_TLineRange], node) -> _TLineRange: return cls(node.lineno, node.end_lineno) def __contains__(self, line): @@ -80,34 +83,34 @@ class _SourceTree: class _BlockReturns: """Tracks return statements in with/try blocks.""" - def __init__(self): + def __init__(self) -> None: self._block_ranges = [] self._returns = [] self._block_returns = {} self._final = False - def add_block(self, node): + def add_block(self, node) -> None: line_range = LineRange.from_node(node) self._block_ranges.append(line_range) - def add_return(self, node): + def add_return(self, node) -> None: self._returns.append(node.lineno) - def finalize(self): + def finalize(self) -> None: for br in self._block_ranges: self._block_returns[br.start_line] = sorted( r for r in self._returns if r in br ) self._final = True - def all_returns(self): + def all_returns(self) -> set: return set(self._returns) def __iter__(self): assert self._final return iter(self._block_returns.items()) - def __repr__(self): + def __repr__(self) -> str: return f""" Blocks: {self._block_ranges} Returns: {self._returns} @@ -134,10 +137,10 @@ class _Match: class _Matches: """Tracks branches of match statements.""" - def __init__(self): + def __init__(self) -> None: self.matches = [] - def add_match(self, start, end, cases): + def add_match(self, start, end, cases) -> None: self.matches.append(_Match(start, end, cases)) @@ -161,7 +164,7 @@ class _ParseVisitor(visitor.BaseVisitor): appears, if any. 
""" - def __init__(self, raw_structured_comments): + def __init__(self, raw_structured_comments) -> None: super().__init__(ast) self._raw_structured_comments = raw_structured_comments # We initialize structured_comment_groups with single-line groups for all @@ -227,7 +230,7 @@ def _add_structured_comment_group(self, start_line, end_line, cls=LineRange): self.structured_comment_groups.move_to_end(k) return new_group - def _process_structured_comments(self, line_range, cls=LineRange): + def _process_structured_comments(self, line_range, cls=LineRange) -> None: def should_add(comment, group): # Don't add the comment more than once. @@ -259,19 +262,19 @@ def should_add(comment, group): if cls is not LineRange: group.extend(c for c in structured_comments if should_add(c, group)) - def leave_Module(self, node): + def leave_Module(self, node) -> None: self.block_returns.finalize() - def visit_Call(self, node): + def visit_Call(self, node) -> None: self._process_structured_comments(LineRange.from_node(node), cls=Call) - def visit_Compare(self, node): + def visit_Compare(self, node) -> None: self._process_structured_comments(LineRange.from_node(node), cls=Call) - def visit_Subscript(self, node): + def visit_Subscript(self, node) -> None: self._process_structured_comments(LineRange.from_node(node), cls=Call) - def visit_AnnAssign(self, node): + def visit_AnnAssign(self, node) -> None: if not node.value: # vm.py preprocesses the source code so that all annotations in function # bodies have values. 
So the only annotations without values are module- @@ -288,40 +291,40 @@ def visit_AnnAssign(self, node): ) self._process_structured_comments(LineRange.from_node(node)) - def _visit_try(self, node): + def _visit_try(self, node) -> None: for handler in node.handlers: if handler.type: self._process_structured_comments(LineRange.from_node(handler.type)) - def visit_Try(self, node): + def visit_Try(self, node) -> None: self._visit_try(node) - def visit_TryStar(self, node): + def visit_TryStar(self, node) -> None: self._visit_try(node) - def _visit_with(self, node): + def _visit_with(self, node) -> None: item = node.items[-1] end_lineno = (item.optional_vars or item.context_expr).end_lineno if self.block_depth == 1: self.block_returns.add_block(node) self._process_structured_comments(LineRange(node.lineno, end_lineno)) - def enter_With(self, node): + def enter_With(self, node) -> None: self.block_depth += 1 - def leave_With(self, node): + def leave_With(self, node) -> None: self.block_depth -= 1 - def enter_AsyncWith(self, node): + def enter_AsyncWith(self, node) -> None: self.block_depth += 1 - def leave_AsyncWith(self, node): + def leave_AsyncWith(self, node) -> None: self.block_depth -= 1 - def visit_With(self, node): + def visit_With(self, node) -> None: self._visit_with(node) - def visit_AsyncWith(self, node): + def visit_AsyncWith(self, node) -> None: self._visit_with(node) def _is_underscore(self, node): @@ -333,7 +336,7 @@ def _is_underscore(self, node): else: return self._is_underscore(node.pattern) - def visit_Match(self, node): + def visit_Match(self, node) -> None: start = node.lineno end = node.end_lineno cases = [] @@ -353,7 +356,7 @@ def visit_Match(self, node): cases.append(match_case) self.matches.add_match(start, end, cases) - def generic_visit(self, node): + def generic_visit(self, node) -> None: if not isinstance(node, ast.stmt): return if hasattr(node, "body"): @@ -372,11 +375,11 @@ def generic_visit(self, node): else: 
self._process_structured_comments(LineRange.from_node(node)) - def visit_Return(self, node): + def visit_Return(self, node) -> None: self.block_returns.add_return(node) self._process_structured_comments(LineRange.from_node(node)) - def _visit_decorators(self, node): + def _visit_decorators(self, node) -> None: if not node.decorator_list: return for dec in node.decorator_list: @@ -384,15 +387,15 @@ def _visit_decorators(self, node): dec_base = dec.func if isinstance(dec, ast.Call) else dec self.decorators[node.lineno].append((dec.lineno, ast.unparse(dec_base))) - def _visit_def(self, node): + def _visit_def(self, node) -> None: self._visit_decorators(node) if not self.defs_start or node.lineno < self.defs_start: self.defs_start = node.lineno - def visit_ClassDef(self, node): + def visit_ClassDef(self, node) -> None: self._visit_def(node) - def _visit_function_def(self, node): + def _visit_function_def(self, node) -> None: # A function signature's line range should start at the beginning of the # signature and end at the final colon. Since we can't get the lineno of # the final colon from the ast, we do our best to approximate it. 
@@ -438,14 +441,14 @@ def _visit_function_def(self, node): _ParamAnnotations(node.lineno, node.end_lineno, node.name, annots) ) - def visit_FunctionDef(self, node): + def visit_FunctionDef(self, node) -> None: self._visit_function_def(node) - def visit_AsyncFunctionDef(self, node): + def visit_AsyncFunctionDef(self, node) -> None: self._visit_function_def(node) -def _process_comments(src): +def _process_comments(src) -> collections.defaultdict[int, Any]: structured_comments = collections.defaultdict(list) f = io.StringIO(src) for token in tokenize.generate_tokens(f.readline): @@ -457,7 +460,9 @@ def _process_comments(src): return structured_comments -def _process_comment(line, lineno, col): +def _process_comment( + line, lineno, col +) -> Generator[_StructuredComment, Any, None]: """Process a single comment.""" matches = list(_DIRECTIVE_RE.finditer(line[col:])) if not matches: @@ -477,14 +482,14 @@ def _process_comment(line, lineno, col): yield _StructuredComment(lineno, tool, data, open_ended) -def parse_src(src: str, python_version: tuple[int, int]): +def parse_src(src: str, python_version: tuple[int, int]) -> _SourceTree: """Parses a string of source code into an ast.""" return _SourceTree( ast.parse(src, feature_version=python_version[1]), _process_comments(src) ) # pylint: disable=unexpected-keyword-arg -def visit_src_tree(src_tree): +def visit_src_tree(src_tree) -> _ParseVisitor: parse_visitor = _ParseVisitor(src_tree.structured_comments) parse_visitor.visit(src_tree.ast) return parse_visitor diff --git a/pytype/errors/error_printer.py b/pytype/errors/error_printer.py index 3ddc1bc46..168a1db59 100644 --- a/pytype/errors/error_printer.py +++ b/pytype/errors/error_printer.py @@ -1,8 +1,10 @@ """A printer for human-readable output of error messages.""" import collections +from collections.abc import Generator import dataclasses import enum +from typing import Any from pytype import matcher from pytype import pretty_printer_base @@ -51,7 +53,7 @@ def __init__( 
self.bad_call = bad_call self._pp = pp - def _iter_sig(self): + def _iter_sig(self) -> Generator[tuple[str, Any], Any, None]: """Iterate through a Signature object. Focus on a bad parameter.""" sig = self.bad_call.sig for name in sig.posonly_params: @@ -69,7 +71,7 @@ def _iter_sig(self): if sig.kwargs_name is not None: yield "**", sig.kwargs_name - def _iter_expected(self): + def _iter_expected(self) -> Generator[tuple[Any, Any, str], Any, None]: """Yield the prefix, name and type information for expected parameters.""" bad_param = self.bad_call.bad_param sig = self.bad_call.sig @@ -80,7 +82,7 @@ def _iter_expected(self): suffix = ": " + type_str + suffix yield prefix, name, suffix - def _iter_actual(self, literal): + def _iter_actual(self, literal) -> Generator[tuple[str, str, str], Any, None]: """Yield the prefix, name and type information for actual parameters.""" # We want to display the passed_args in the order they're defined in the # signature, unless there are starargs or starstarargs. @@ -106,7 +108,7 @@ def key_f(arg): suffix = "" yield "", name, suffix - def _print_args(self, arg_iter): + def _print_args(self, arg_iter) -> str: """Pretty-print a list of arguments. Focus on a bad parameter.""" # (foo, bar, broken : type, ...) 
bad_param = self.bad_call.bad_param @@ -125,7 +127,7 @@ def _print_args(self, arg_iter): printed_params.append(prefix + name) return ", ".join(printed_params) - def print_call_details(self): + def print_call_details(self) -> BadCall: bad_param = self.bad_call.bad_param expected = self._print_args(self._iter_expected()) literal = "Literal[" in expected diff --git a/pytype/errors/error_types.py b/pytype/errors/error_types.py index 127086f92..b074f5b28 100644 --- a/pytype/errors/error_types.py +++ b/pytype/errors/error_types.py @@ -2,7 +2,7 @@ from collections.abc import Sequence import dataclasses -from typing import Optional +from typing import Any, Optional from pytype.types import types @@ -10,16 +10,16 @@ class ReturnValueMixin: """Mixin for exceptions that hold a return node and variable.""" - def __init__(self): + def __init__(self) -> None: super().__init__() self.return_node = None self.return_variable = None - def set_return(self, node, var): + def set_return(self, node, var) -> None: self.return_node = node self.return_variable = var - def get_return(self, state): + def get_return(self, state) -> tuple[Any, Any]: return state.change_cfg_node(self.return_node), self.return_variable @@ -40,21 +40,21 @@ class BadType: class FailedFunctionCall(Exception, ReturnValueMixin): """Exception for failed function calls.""" - def __init__(self): + def __init__(self) -> None: super().__init__() self.name = "" - def __gt__(self, other): + def __gt__(self, other) -> bool: return other is None - def __le__(self, other): + def __le__(self, other) -> bool: return not self.__gt__(other) class NotCallable(FailedFunctionCall): """For objects that don't have __call__.""" - def __init__(self, obj): + def __init__(self, obj) -> None: super().__init__() self.obj = obj @@ -62,7 +62,7 @@ def __init__(self, obj): class UndefinedParameterError(FailedFunctionCall): """Function called with an undefined variable.""" - def __init__(self, name): + def __init__(self, name) -> None: 
super().__init__() self.name = name @@ -70,14 +70,14 @@ def __init__(self, name): class DictKeyMissing(Exception, ReturnValueMixin): """When retrieving a key that does not exist in a dict.""" - def __init__(self, name): + def __init__(self, name) -> None: super().__init__() self.name = name - def __gt__(self, other): + def __gt__(self, other) -> bool: return other is None - def __le__(self, other): + def __le__(self, other) -> bool: return not self.__gt__(other) @@ -91,7 +91,7 @@ class BadCall: class InvalidParameters(FailedFunctionCall): """Exception for functions called with an incorrect parameter combination.""" - def __init__(self, sig, passed_args, ctx, bad_param=None): + def __init__(self, sig, passed_args, ctx, bad_param=None) -> None: super().__init__() self.name = sig.name passed_args = [ @@ -106,7 +106,7 @@ def __init__(self, sig, passed_args, ctx, bad_param=None): class WrongArgTypes(InvalidParameters): """For functions that were called with the wrong types.""" - def __init__(self, sig, passed_args, ctx, bad_param): + def __init__(self, sig, passed_args, ctx, bad_param) -> None: if not sig.has_param(bad_param.name): sig = sig.insert_varargs_and_kwargs( name for name, *_ in sig.iter_args(passed_args) @@ -129,7 +129,7 @@ def starcount(err): return starcount(self) < starcount(other) - def __le__(self, other): + def __le__(self, other) -> bool: return not self.__gt__(other) @@ -140,7 +140,7 @@ class WrongArgCount(InvalidParameters): class WrongKeywordArgs(InvalidParameters): """E.g. an arg "x" is passed to a function that doesn't have an "x" param.""" - def __init__(self, sig, passed_args, ctx, extra_keywords): + def __init__(self, sig, passed_args, ctx, extra_keywords) -> None: super().__init__(sig, passed_args, ctx) self.extra_keywords = tuple(extra_keywords) @@ -148,7 +148,7 @@ def __init__(self, sig, passed_args, ctx, extra_keywords): class DuplicateKeyword(InvalidParameters): """E.g. 
an arg "x" is passed to a function as both a posarg and a kwarg.""" - def __init__(self, sig, passed_args, ctx, duplicate): + def __init__(self, sig, passed_args, ctx, duplicate) -> None: super().__init__(sig, passed_args, ctx) self.duplicate = duplicate @@ -156,7 +156,7 @@ def __init__(self, sig, passed_args, ctx, duplicate): class MissingParameter(InvalidParameters): """E.g. a function requires parameter 'x' but 'x' isn't passed.""" - def __init__(self, sig, passed_args, ctx, missing_parameter): + def __init__(self, sig, passed_args, ctx, missing_parameter) -> None: super().__init__(sig, passed_args, ctx) self.missing_parameter = missing_parameter @@ -178,7 +178,7 @@ def __init__(self, typed_dict: types.BaseValue, key: str | None): class MatchError(Exception): - def __init__(self, bad_type: BadType, *args, **kwargs): + def __init__(self, bad_type: BadType, *args, **kwargs) -> None: self.bad_type = bad_type super().__init__(bad_type, *args, **kwargs) @@ -186,7 +186,7 @@ def __init__(self, bad_type: BadType, *args, **kwargs): class NonIterableStrError(Exception): """Error for matching `str` against `Iterable[str]`/`Sequence[str]`/etc.""" - def __init__(self, left_type, other_type): + def __init__(self, left_type, other_type) -> None: super().__init__() self.left_type = left_type self.other_type = other_type @@ -194,7 +194,7 @@ def __init__(self, left_type, other_type): class ProtocolError(Exception): - def __init__(self, left_type, other_type): + def __init__(self, left_type, other_type) -> None: super().__init__() self.left_type = left_type self.other_type = other_type @@ -202,14 +202,16 @@ def __init__(self, left_type, other_type): class ProtocolMissingAttributesError(ProtocolError): - def __init__(self, left_type, other_type, missing): + def __init__(self, left_type, other_type, missing) -> None: super().__init__(left_type, other_type) self.missing = missing class ProtocolTypeError(ProtocolError): - def __init__(self, left_type, other_type, attribute, actual, 
expected): + def __init__( + self, left_type, other_type, attribute, actual, expected + ) -> None: super().__init__(left_type, other_type) self.attribute_name = attribute self.actual_type = actual @@ -218,7 +220,7 @@ def __init__(self, left_type, other_type, attribute, actual, expected): class TypedDictError(Exception): - def __init__(self, bad, extra, missing): + def __init__(self, bad, extra, missing) -> None: super().__init__() self.bad = bad self.missing = missing diff --git a/pytype/errors/errors.py b/pytype/errors/errors.py index d6277eb68..e813f855b 100644 --- a/pytype/errors/errors.py +++ b/pytype/errors/errors.py @@ -1,12 +1,13 @@ """Code and data structures for storing and displaying errors.""" -from collections.abc import Callable, Iterable, Sequence +from collections.abc import Callable, Iterable, Iterator, Sequence import contextlib import csv import io import logging import sys -from typing import IO, TypeVar +from typing import Any, TypeVar +from typing import IO from pytype import debug from pytype import pretty_printer_base @@ -16,8 +17,11 @@ from pytype.pytd import slots from pytype.types import types + +_TError = TypeVar("_TError", bound="Error") + # Usually we call the logger "log" but that name is used quite often here. -_log = logging.getLogger(__name__) +_log: logging.Logger = logging.getLogger(__name__) # "Error level" enum for distinguishing between warnings and errors: @@ -25,10 +29,10 @@ SEVERITY_ERROR = 2 # The set of known error names. -_ERROR_NAMES = set() +_ERROR_NAMES: set[str] = set() # The current error name, managed by the error_name decorator. -_CURRENT_ERROR_NAME = utils.DynamicVar() +_CURRENT_ERROR_NAME: utils.DynamicVar = utils.DynamicVar() # Max number of calls in the traceback string. MAX_TRACEBACK_LENGTH = 3 @@ -40,7 +44,7 @@ TRACEBACK_MARKER = "Called from (traceback):" # Symbol representing an elided portion of the stack. 
-_ELLIPSIS = object() +_ELLIPSIS: Any = object() _FuncT = TypeVar("_FuncT", bound=Callable) @@ -79,7 +83,7 @@ def _maybe_truncate_traceback(traceback): return traceback -def _make_traceback_str(frames): +def _make_traceback_str(frames) -> str | None: """Turn a stack of frames into a traceback string.""" if len(frames) < 2 or ( frames[-1].f_code and not frames[-1].f_code.get_arg_count() @@ -106,7 +110,7 @@ def _make_traceback_str(frames): return TRACEBACK_MARKER + "\n " + "\n ".join(traceback) -def _dedup_opcodes(stack): +def _dedup_opcodes(stack) -> list: """Dedup the opcodes in a stack of frames.""" deduped_stack = [] if len(stack) > 1: @@ -125,7 +129,7 @@ def _dedup_opcodes(stack): return deduped_stack -def _compare_traceback_strings(left, right): +def _compare_traceback_strings(left, right) -> int | None: """Try to compare two traceback strings. Two traceback strings are comparable if they are equal, or if one ends with @@ -156,7 +160,7 @@ def _compare_traceback_strings(left, right): return None -def _function_name(name, capitalize=False): +def _function_name(name, capitalize=False) -> str: builtin_prefix = "builtins." 
if name.startswith(builtin_prefix): ret = f"built-in function {name[len(builtin_prefix):]}" @@ -171,12 +175,12 @@ def _function_name(name, capitalize=False): class CheckPoint: """Represents a position in an error log.""" - def __init__(self, errors): + def __init__(self, errors) -> None: self._errorlog_errors = errors self._position = len(errors) self.errors = None - def revert(self): + def revert(self) -> None: self.errors = self._errorlog_errors[self._position :] self._errorlog_errors[:] = self._errorlog_errors[: self._position] @@ -222,7 +226,7 @@ def __init__( keyword_context=None, bad_call=None, opcode_name=None, - ): + ) -> None: name = _CURRENT_ERROR_NAME.get() assert ( name @@ -248,7 +252,9 @@ def __init__( self._opcode_name = opcode_name @classmethod - def with_stack(cls, stack, severity, message, **kwargs): + def with_stack( + cls: type[_TError], stack, severity, message, **kwargs + ) -> _TError: """Return an error using a stack for position information. Args: @@ -280,7 +286,9 @@ def with_stack(cls, stack, severity, message, **kwargs): ) @classmethod - def for_test(cls, severity, message, name, **kwargs): + def for_test( + cls: type[_TError], severity, message, name, **kwargs + ) -> _TError: """Create an _Error with the specified name, for use in tests.""" with _CURRENT_ERROR_NAME.bind(name): return cls(severity, message, **kwargs) @@ -409,10 +417,10 @@ def _visualize_failed_lines(self) -> str: ) return "".join(concat_code_with_red_lines) - def get_unique_representation(self): + def get_unique_representation(self) -> tuple[Any, Any, Any, Any]: return (self._position(), self._message, self._details, self._name) - def _position(self): + def _position(self) -> str: """Return human-readable filename + line number.""" method = f"in {self._methodname}" if self._methodname else "" @@ -437,10 +445,10 @@ def _position(self): def __str__(self): return self.as_string() - def set_line(self, line): + def set_line(self, line) -> None: self._line = line - def 
as_string(self, *, color=False): + def as_string(self, *, color=False) -> str: """Format the error as a friendly string, optionally with shell coloring.""" pos = self._position() if pos: @@ -459,7 +467,7 @@ def as_string(self, *, color=False): text += "\n" + self._traceback return text - def drop_traceback(self): + def drop_traceback(self: _TError) -> _TError: with _CURRENT_ERROR_NAME.bind(self._name): return self.__class__( severity=self._severity, @@ -480,22 +488,22 @@ def drop_traceback(self): class ErrorLog: """A stream of errors.""" - def __init__(self, src: str): + def __init__(self, src: str) -> None: self._errors = [] # An error filter (initially None) self._filter = None self._src = src - def __len__(self): + def __len__(self) -> int: return len(self._errors) - def __iter__(self): + def __iter__(self) -> Iterator[None]: return iter(self._errors) def __getitem__(self, index): return self._errors[index] - def copy_from(self, errors, stack): + def copy_from(self, errors, stack) -> None: for e in errors: with _CURRENT_ERROR_NAME.bind(e.name): self.error( @@ -507,11 +515,11 @@ def copy_from(self, errors, stack): e.keyword_context, ) - def is_valid_error_name(self, name): + def is_valid_error_name(self, name) -> bool: """Return True iff name was defined in an @error_name() decorator.""" return name in _ERROR_NAMES - def set_error_filter(self, filt): + def set_error_filter(self, filt) -> None: """Set the error filter. 
Args: @@ -523,19 +531,19 @@ def set_error_filter(self, filt): """ self._filter = filt - def has_error(self): + def has_error(self) -> bool: """Return true iff an Error with SEVERITY_ERROR is present.""" # pylint: disable=protected-access return any(e._severity == SEVERITY_ERROR for e in self._errors) - def _add(self, error): + def _add(self, error) -> None: if self._filter is None or self._filter(error): _log.info("Added error to log: %s\n%s", error.name, error) if _log.isEnabledFor(logging.DEBUG): _log.debug(debug.stack_trace(limit=1).rstrip()) self._errors.append(error) - def warn(self, stack, message, *args): + def warn(self, stack, message, *args) -> None: self._add( Error.with_stack(stack, SEVERITY_WARNING, message % args, src=self._src) ) @@ -549,7 +557,7 @@ def error( bad_call=None, keyword_context=None, line=None, - ): + ) -> None: err = Error.with_stack( stack, SEVERITY_ERROR, @@ -578,7 +586,7 @@ def checkpoint(self): len(checkpoint.errors), ) - def print_to_csv_file(self, fi: IO[str]): + def print_to_csv_file(self, fi: IO[str]) -> None: """Print the errorlog to a csv file.""" csv_file = csv.writer(fi, delimiter=",", lineterminator="\n") for error in self.unique_sorted_errors(): @@ -593,11 +601,11 @@ def print_to_csv_file(self, fi: IO[str]): [error._filename, error._line, error._name, error._message, details] ) - def print_to_file(self, fi: IO[str], *, color: bool = False): + def print_to_file(self, fi: IO[str], *, color: bool = False) -> None: for error in self.unique_sorted_errors(): print(error.as_string(color=color), file=fi) - def unique_sorted_errors(self): + def unique_sorted_errors(self) -> list: """Gets the unique errors in this log, sorted on filename and line.""" unique_errors = {} for error in self._sorted_errors(): @@ -630,13 +638,13 @@ def unique_sorted_errors(self): errors.append(error) return sum(unique_errors.values(), []) - def _sorted_errors(self): + def _sorted_errors(self) -> list: return sorted(self._errors, key=lambda x: (x.filename 
or "", x.line)) - def print_to_stderr(self, *, color=True): + def print_to_stderr(self, *, color=True) -> None: self.print_to_file(sys.stderr, color=color) - def __str__(self): + def __str__(self) -> str: f = io.StringIO() self.print_to_file(f) return f.getvalue() @@ -654,13 +662,13 @@ def pretty_printer(self) -> pretty_printer_base.PrettyPrinterBase: return self._pp @_error_name("pyi-error") - def pyi_error(self, stack, name, error): + def pyi_error(self, stack, name, error) -> None: self.error( stack, f"Couldn't import pyi for {name!r}", str(error), keyword=name ) @_error_name("attribute-error") - def _attribute_error(self, stack, binding, obj_repr, attr_name): + def _attribute_error(self, stack, binding, obj_repr, attr_name) -> None: """Log an attribute error.""" if len(binding.variable.bindings) > 1: # Joining the printed types rather than merging them before printing @@ -678,7 +686,7 @@ def _attribute_error(self, stack, binding, obj_repr, attr_name): ) @_error_name("not-writable") - def not_writable(self, stack, obj, attr_name): + def not_writable(self, stack, obj, attr_name) -> None: obj_repr = self._pp.print_type(obj) self.error( stack, @@ -688,7 +696,7 @@ def not_writable(self, stack, obj, attr_name): ) @_error_name("module-attr") - def _module_attr(self, stack, module_name, attr_name): + def _module_attr(self, stack, module_name, attr_name) -> None: self.error( stack, f"No attribute {attr_name!r} on module {module_name!r}", @@ -696,7 +704,7 @@ def _module_attr(self, stack, module_name, attr_name): keyword_context=module_name, ) - def attribute_error(self, stack, binding, attr_name): + def attribute_error(self, stack, binding, attr_name) -> None: ep = error_printer.AttributeErrorPrinter(self._pp) recv = ep.print_receiver(binding.data, attr_name) if recv.obj_type == error_printer.BadAttrType.SYMBOL: @@ -710,7 +718,7 @@ def attribute_error(self, stack, binding, attr_name): assert False, recv.obj_type @_error_name("unbound-type-param") - def 
unbound_type_param(self, stack, obj, attr_name, type_param_name): + def unbound_type_param(self, stack, obj, attr_name, type_param_name) -> None: self.error( stack, f"Can't access attribute {attr_name!r} on {obj.name}", @@ -720,18 +728,18 @@ def unbound_type_param(self, stack, obj, attr_name, type_param_name): ) @_error_name("name-error") - def name_error(self, stack, name, details=None): + def name_error(self, stack, name, details=None) -> None: self.error( stack, f"Name {name!r} is not defined", keyword=name, details=details ) @_error_name("import-error") - def import_error(self, stack, module_name): + def import_error(self, stack, module_name) -> None: self.error( stack, f"Can't find module {module_name!r}.", keyword=module_name ) - def _invalid_parameters(self, stack, message, bad_call): + def _invalid_parameters(self, stack, message, bad_call) -> None: """Log an invalid parameters error.""" ret = error_printer.BadCallPrinter(self._pp, bad_call).print_call_details() details = "".join( @@ -748,7 +756,7 @@ def _invalid_parameters(self, stack, message, bad_call): self.error(stack, message, details, bad_call=bad_call) @_error_name("wrong-arg-count") - def wrong_arg_count(self, stack, name, bad_call): + def wrong_arg_count(self, stack, name, bad_call) -> None: message = "%s expects %d arg(s), got %d" % ( _function_name(name, capitalize=True), bad_call.sig.mandatory_param_count(), @@ -756,7 +764,9 @@ def wrong_arg_count(self, stack, name, bad_call): ) self._invalid_parameters(stack, message, bad_call) - def _get_binary_operation(self, function_name, bad_call): + def _get_binary_operation( + self, function_name, bad_call + ) -> tuple[Any, str, str] | None: """Return (op, left, right) if the function should be treated as a binop.""" maybe_left_operand, _, f = function_name.rpartition(".") # Check that @@ -782,7 +792,7 @@ def _get_binary_operation(self, function_name, bad_call): return f, left_operand, right_operand return None - def wrong_arg_types(self, stack, name, 
bad_call): + def wrong_arg_types(self, stack, name, bad_call) -> None: """Log [wrong-arg-types].""" operation = self._get_binary_operation(name, bad_call) if operation: @@ -801,7 +811,7 @@ def wrong_arg_types(self, stack, name, bad_call): self._wrong_arg_types(stack, name, bad_call) @_error_name("wrong-arg-types") - def _wrong_arg_types(self, stack, name, bad_call): + def _wrong_arg_types(self, stack, name, bad_call) -> None: """A function was called with the wrong parameter types.""" message = "%s was called with the wrong arguments" % _function_name( name, capitalize=True @@ -809,7 +819,7 @@ def _wrong_arg_types(self, stack, name, bad_call): self._invalid_parameters(stack, message, bad_call) @_error_name("wrong-keyword-args") - def wrong_keyword_args(self, stack, name, bad_call, extra_keywords): + def wrong_keyword_args(self, stack, name, bad_call, extra_keywords) -> None: """A function was called with extra keywords.""" if len(extra_keywords) == 1: message = "Invalid keyword argument {} to {}".format( @@ -822,7 +832,7 @@ def wrong_keyword_args(self, stack, name, bad_call, extra_keywords): self._invalid_parameters(stack, message, bad_call) @_error_name("missing-parameter") - def missing_parameter(self, stack, name, bad_call, missing_parameter): + def missing_parameter(self, stack, name, bad_call, missing_parameter) -> None: """A function call is missing parameters.""" message = "Missing parameter {!r} in call to {}".format( missing_parameter, _function_name(name) @@ -830,7 +840,7 @@ def missing_parameter(self, stack, name, bad_call, missing_parameter): self._invalid_parameters(stack, message, bad_call) @_error_name("not-callable") - def not_callable(self, stack, func, details=None): + def not_callable(self, stack, func, details=None) -> None: """Calling an object that isn't callable.""" if isinstance(func, types.Function) and func.is_overload: prefix = "@typing.overload-decorated " @@ -840,7 +850,7 @@ def not_callable(self, stack, func, details=None): 
self.error(stack, message, keyword=func.name, details=details) @_error_name("not-indexable") - def not_indexable(self, stack, name, generic_warning=False): + def not_indexable(self, stack, name, generic_warning=False) -> None: message = f"class {name} is not indexable" if generic_warning: self.error( @@ -850,7 +860,7 @@ def not_indexable(self, stack, name, generic_warning=False): self.error(stack, message, keyword=name) @_error_name("not-instantiable") - def not_instantiable(self, stack, cls): + def not_instantiable(self, stack, cls) -> None: """Instantiating an abstract class.""" message = "Can't instantiate {} with abstract methods {}".format( cls.full_name, ", ".join(sorted(cls.abstract_methods)) @@ -858,7 +868,7 @@ def not_instantiable(self, stack, cls): self.error(stack, message) @_error_name("ignored-abstractmethod") - def ignored_abstractmethod(self, stack, cls_name, method_name): + def ignored_abstractmethod(self, stack, cls_name, method_name) -> None: message = f"Stray abc.abstractmethod decorator on method {method_name}" self.error( stack, @@ -867,12 +877,12 @@ def ignored_abstractmethod(self, stack, cls_name, method_name): ) @_error_name("ignored-metaclass") - def ignored_metaclass(self, stack, cls, metaclass): + def ignored_metaclass(self, stack, cls, metaclass) -> None: message = f"Metaclass {metaclass} on class {cls} ignored in Python 3" self.error(stack, message) @_error_name("duplicate-keyword-argument") - def duplicate_keyword(self, stack, name, bad_call, duplicate): + def duplicate_keyword(self, stack, name, bad_call, duplicate) -> None: message = "%s got multiple values for keyword argument %r" % ( _function_name(name), duplicate, @@ -880,10 +890,10 @@ def duplicate_keyword(self, stack, name, bad_call, duplicate): self._invalid_parameters(stack, message, bad_call) @_error_name("invalid-super-call") - def invalid_super_call(self, stack, message, details=None): + def invalid_super_call(self, stack, message, details=None) -> None: self.error(stack, 
message, details) - def invalid_function_call(self, stack, error): + def invalid_function_call(self, stack, error) -> None: """Log an invalid function call.""" # Make sure method names are prefixed with the class name. if ( @@ -921,7 +931,7 @@ def invalid_function_call(self, stack, error): raise AssertionError(error) @_error_name("base-class-error") - def base_class_error(self, stack, base_var, details=None): + def base_class_error(self, stack, base_var, details=None) -> None: base_cls = self._pp.join_printed_types( self._pp.print_type_of_instance(t) for t in base_var.data ) @@ -933,7 +943,7 @@ def base_class_error(self, stack, base_var, details=None): ) @_error_name("bad-return-type") - def bad_return_type(self, stack, node, bad): + def bad_return_type(self, stack, node, bad) -> None: """Logs a [bad-return-type] error.""" ret = error_printer.MatcherErrorPrinter(self._pp).print_return_types( @@ -954,7 +964,7 @@ def bad_return_type(self, stack, node, bad): self.error(stack, message, "".join(details)) @_error_name("bad-return-type") - def any_return_type(self, stack): + def any_return_type(self, stack) -> None: """Logs a [bad-return-type] error.""" message = "Return type may not be Any" details = [ @@ -964,7 +974,7 @@ def any_return_type(self, stack): self.error(stack, message, "".join(details)) @_error_name("bad-yield-annotation") - def bad_yield_annotation(self, stack, name, annot, is_async): + def bad_yield_annotation(self, stack, name, annot, is_async) -> None: func = ("async " if is_async else "") + f"generator function {name}" actual = self._pp.print_type_of_instance(annot) message = f"Bad return type {actual!r} for {func}" @@ -975,7 +985,7 @@ def bad_yield_annotation(self, stack, name, annot, is_async): self.error(stack, message, details) @_error_name("bad-concrete-type") - def bad_concrete_type(self, stack, node, bad, details=None): + def bad_concrete_type(self, stack, node, bad, details=None) -> None: ret = 
error_printer.MatcherErrorPrinter(self._pp).print_return_types( node, bad ) @@ -993,7 +1003,7 @@ def bad_concrete_type(self, stack, node, bad, details=None): stack, "Invalid instantiation of generic class", "".join(full_details) ) - def unsupported_operands(self, stack, operator, var1, var2): + def unsupported_operands(self, stack, operator, var1, var2) -> None: left = self._pp.show_variable(var1) right = self._pp.show_variable(var2) details = f"No attribute {operator!r} on {left}" @@ -1002,7 +1012,9 @@ def unsupported_operands(self, stack, operator, var1, var2): self._unsupported_operands(stack, operator, left, right, details=details) @_error_name("unsupported-operands") - def _unsupported_operands(self, stack, operator, *operands, details=None): + def _unsupported_operands( + self, stack, operator, *operands, details=None + ) -> None: """Unsupported operands.""" # `operator` is sometimes the symbol and sometimes the method name, so we # need to check for both here. @@ -1036,7 +1048,7 @@ def invalid_annotation( annot = self._pp.print_type_of_instance(annot) self._invalid_annotation(stack, annot, details, name) - def _print_params_helper(self, param_or_params): + def _print_params_helper(self, param_or_params) -> str: if isinstance(param_or_params, types.BaseValue): return self._pp.print_type_of_instance(param_or_params) else: @@ -1067,7 +1079,7 @@ def wrong_annotation_parameter_count( ) self._invalid_annotation(stack, full_type, details, name=None) - def invalid_ellipses(self, stack, indices, container_name): + def invalid_ellipses(self, stack, indices, container_name) -> None: if indices: details = "Not allowed at {} {} in {}".format( "index" if len(indices) == 1 else "indices", @@ -1092,7 +1104,7 @@ def ambiguous_annotation( self._invalid_annotation(stack, desc, "Must be constant", name) @_error_name("invalid-annotation") - def _invalid_annotation(self, stack, annot_string, details, name): + def _invalid_annotation(self, stack, annot_string, details, name) -> 
None: """Log the invalid annotation.""" if name is None: suffix = "" @@ -1106,7 +1118,7 @@ def _invalid_annotation(self, stack, annot_string, details, name): ) @_error_name("mro-error") - def mro_error(self, stack, name, mro_seqs, details=None): + def mro_error(self, stack, name, mro_seqs, details=None) -> None: seqs = [] for seq in mro_seqs: seqs.append(f"[{', '.join(cls.name for cls in seq)}]") @@ -1115,7 +1127,7 @@ def mro_error(self, stack, name, mro_seqs, details=None): self.error(stack, msg, keyword=name, details=details) @_error_name("invalid-directive") - def invalid_directive(self, filename, line, message): + def invalid_directive(self, filename, line, message) -> None: self._add( Error( SEVERITY_WARNING, @@ -1127,7 +1139,7 @@ def invalid_directive(self, filename, line, message): ) @_error_name("late-directive") - def late_directive(self, filename, line, name): + def late_directive(self, filename, line, name) -> None: message = f"{name} disabled from here to the end of the file" details = ( "Consider limiting this directive's scope or moving it to the " @@ -1145,11 +1157,11 @@ def late_directive(self, filename, line, name): ) @_error_name("not-supported-yet") - def not_supported_yet(self, stack, feature, details=None): + def not_supported_yet(self, stack, feature, details=None) -> None: self.error(stack, f"{feature} not supported yet", details=details) @_error_name("python-compiler-error") - def python_compiler_error(self, filename, line, message): + def python_compiler_error(self, filename, line, message) -> None: self._add( Error( SEVERITY_ERROR, message, filename=filename, line=line, src=self._src @@ -1157,11 +1169,11 @@ def python_compiler_error(self, filename, line, message): ) @_error_name("recursion-error") - def recursion_error(self, stack, name): + def recursion_error(self, stack, name) -> None: self.error(stack, f"Detected recursion in {name}", keyword=name) @_error_name("redundant-function-type-comment") - def 
redundant_function_type_comment(self, filename, line): + def redundant_function_type_comment(self, filename, line) -> None: self._add( Error( SEVERITY_ERROR, @@ -1173,13 +1185,13 @@ def redundant_function_type_comment(self, filename, line): ) @_error_name("invalid-function-type-comment") - def invalid_function_type_comment(self, stack, comment, details=None): + def invalid_function_type_comment(self, stack, comment, details=None) -> None: self.error( stack, f"Invalid function type comment: {comment}", details=details ) @_error_name("ignored-type-comment") - def ignored_type_comment(self, filename, line, comment): + def ignored_type_comment(self, filename, line, comment) -> None: self._add( Error( SEVERITY_WARNING, @@ -1191,14 +1203,14 @@ def ignored_type_comment(self, filename, line, comment): ) @_error_name("invalid-typevar") - def invalid_typevar(self, stack, comment, bad_call=None): + def invalid_typevar(self, stack, comment, bad_call=None) -> None: if bad_call: self._invalid_parameters(stack, comment, bad_call) else: self.error(stack, f"Invalid TypeVar: {comment}") @_error_name("invalid-namedtuple-arg") - def invalid_namedtuple_arg(self, stack, badname=None, err_msg=None): + def invalid_namedtuple_arg(self, stack, badname=None, err_msg=None) -> None: if err_msg is None: msg = ( "collections.namedtuple argument %r is not a valid typename or " @@ -1209,16 +1221,16 @@ def invalid_namedtuple_arg(self, stack, badname=None, err_msg=None): self.error(stack, err_msg) @_error_name("bad-function-defaults") - def bad_function_defaults(self, stack, func_name): + def bad_function_defaults(self, stack, func_name) -> None: msg = "Attempt to set %s.__defaults__ to a non-tuple value." 
self.warn(stack, msg % func_name) @_error_name("bad-slots") - def bad_slots(self, stack, msg): + def bad_slots(self, stack, msg) -> None: self.error(stack, msg) @_error_name("bad-unpacking") - def bad_unpacking(self, stack, num_vals, num_vars): + def bad_unpacking(self, stack, num_vals, num_vars) -> None: prettify = lambda v, label: "%d %s%s" % (v, label, "" if v == 1 else "s") vals_str = prettify(num_vals, "value") vars_str = prettify(num_vars, "variable") @@ -1226,15 +1238,15 @@ def bad_unpacking(self, stack, num_vals, num_vars): self.error(stack, msg, keyword=vals_str) @_error_name("bad-unpacking") - def nondeterministic_unpacking(self, stack): + def nondeterministic_unpacking(self, stack) -> None: self.error(stack, "Unpacking a non-deterministic order iterable.") @_error_name("reveal-type") - def reveal_type(self, stack, node, var): + def reveal_type(self, stack, node, var) -> None: self.error(stack, self._pp.print_var_type(var, node)) @_error_name("assert-type") - def assert_type(self, stack, actual: str, expected: str): + def assert_type(self, stack, actual: str, expected: str) -> None: """Check that a variable type matches its expected value.""" details = f"Expected: {expected}\n Actual: {actual}" self.error(stack, actual, details=details) @@ -1250,7 +1262,7 @@ def annotation_type_mismatch( details=None, *, typed_dict=None, - ): + ) -> None: """Invalid combination of annotation and assignment.""" if annot is None: return @@ -1289,7 +1301,7 @@ def annotation_type_mismatch( self.error(stack, err_msg, details=details) @_error_name("container-type-mismatch") - def container_type_mismatch(self, stack, cls, mutations, name): + def container_type_mismatch(self, stack, cls, mutations, name) -> None: """Invalid combination of annotation and mutation. 
Args: @@ -1325,18 +1337,18 @@ def container_type_mismatch(self, stack, cls, mutations, name): self.error(stack, err_msg, details=details) @_error_name("invalid-function-definition") - def invalid_function_definition(self, stack, msg, details=None): + def invalid_function_definition(self, stack, msg, details=None) -> None: self.error(stack, msg, details=details) @_error_name("invalid-signature-mutation") - def invalid_signature_mutation(self, stack, func_name, sig): + def invalid_signature_mutation(self, stack, func_name, sig) -> None: sig = self._pp.print_pytd(sig) msg = "Invalid self type mutation in pyi method signature" details = f"{func_name}{sig}" self.error(stack, msg, details) @_error_name("typed-dict-error") - def typed_dict_error(self, stack, obj, name): + def typed_dict_error(self, stack, obj, name) -> None: """Accessing a nonexistent key in a typed dict. Args: @@ -1353,7 +1365,9 @@ def typed_dict_error(self, stack, obj, name): self.error(stack, err_msg) @_error_name("final-error") - def _overriding_final(self, stack, cls, base, name, *, is_method, details): + def _overriding_final( + self, stack, cls, base, name, *, is_method, details + ) -> None: desc = "method" if is_method else "class attribute" msg = ( f"Class {cls.name} overrides final {desc} {name}, " @@ -1361,12 +1375,16 @@ def _overriding_final(self, stack, cls, base, name, *, is_method, details): ) self.error(stack, msg, details=details) - def overriding_final_method(self, stack, cls, base, name, details=None): + def overriding_final_method( + self, stack, cls, base, name, details=None + ) -> None: self._overriding_final( stack, cls, base, name, details=details, is_method=True ) - def overriding_final_attribute(self, stack, cls, base, name, details=None): + def overriding_final_attribute( + self, stack, cls, base, name, details=None + ) -> None: self._overriding_final( stack, cls, base, name, details=details, is_method=False ) @@ -1386,7 +1404,7 @@ def _normalize_signature(self, signature): 
@_error_name("signature-mismatch") def overriding_signature_mismatch( self, stack, base_signature, class_signature, details=None - ): + ) -> None: """Signature mismatch between overridden and overriding class methods.""" base_signature = self._normalize_signature(base_signature) class_signature = self._normalize_signature(class_signature) @@ -1401,14 +1419,14 @@ def overriding_signature_mismatch( self.error(stack, "Overriding method signature mismatch", details=details) @_error_name("final-error") - def assigning_to_final(self, stack, name, local): + def assigning_to_final(self, stack, name, local) -> None: """Attempting to reassign a variable annotated with Final.""" obj = "variable" if local else "attribute" err_msg = f"Assigning to {obj} {name}, which was annotated with Final" self.error(stack, err_msg) @_error_name("final-error") - def subclassing_final_class(self, stack, base_var, details=None): + def subclassing_final_class(self, stack, base_var, details=None) -> None: base_cls = self._pp.join_printed_types( self._pp.print_type_of_instance(t) for t in base_var.data ) @@ -1420,7 +1438,7 @@ def subclassing_final_class(self, stack, base_var, details=None): ) @_error_name("final-error") - def bad_final_decorator(self, stack, obj, details=None): + def bad_final_decorator(self, stack, obj, details=None) -> None: name = getattr(obj, "name", None) if not name: typ = self._pp.print_type_of_instance(obj) @@ -1430,7 +1448,7 @@ def bad_final_decorator(self, stack, obj, details=None): self.error(stack, msg, details=details) @_error_name("final-error") - def invalid_final_type(self, stack, details=None): + def invalid_final_type(self, stack, details=None) -> None: msg = "Invalid use of typing.Final" details = ( "Final may only be used as the outermost type in assignments " @@ -1439,7 +1457,9 @@ def invalid_final_type(self, stack, details=None): self.error(stack, msg, details=details) @_error_name("match-error") - def match_posargs_count(self, stack, cls, posargs, 
match_args, details=None): + def match_posargs_count( + self, stack, cls, posargs, match_args, details=None + ) -> None: msg = ( f"{cls.name}() accepts {match_args} positional sub-patterns" f" ({posargs} given)" @@ -1447,38 +1467,38 @@ def match_posargs_count(self, stack, cls, posargs, match_args, details=None): self.error(stack, msg, details=details) @_error_name("match-error") - def bad_class_match(self, stack, obj, details=None): + def bad_class_match(self, stack, obj, details=None) -> None: msg = f"Invalid constructor pattern in match case (not a class): {obj}" self.error(stack, msg, details=details) @_error_name("incomplete-match") - def incomplete_match(self, stack, line, cases, details=None): + def incomplete_match(self, stack, line, cases, details=None) -> None: cases = ", ".join(str(x) for x in cases) msg = f"The match is missing the following cases: {cases}" self.error(stack, msg, details=details, line=line) @_error_name("redundant-match") - def redundant_match(self, stack, case, details=None): + def redundant_match(self, stack, case, details=None) -> None: msg = f"This case has already been covered: {case}." 
self.error(stack, msg, details=details) @_error_name("paramspec-error") - def paramspec_error(self, stack, details=None): + def paramspec_error(self, stack, details=None) -> None: msg = "ParamSpec error" self.error(stack, msg, details=details) @_error_name("dataclass-error") - def dataclass_error(self, stack, details=None): + def dataclass_error(self, stack, details=None) -> None: msg = "Dataclass error" self.error(stack, msg, details=details) @_error_name("override-error") - def no_overridden_attribute(self, stack, attr): + def no_overridden_attribute(self, stack, attr) -> None: msg = f"Attribute {attr!r} not found on any parent class" self.error(stack, msg) @_error_name("override-error") - def missing_override_decorator(self, stack, attr, parent): + def missing_override_decorator(self, stack, attr, parent) -> None: parent_attr = f"{parent}.{attr}" msg = ( f"Missing @typing.override decorator for {attr!r}, which overrides " @@ -1487,5 +1507,5 @@ def missing_override_decorator(self, stack, attr, parent): self.error(stack, msg) -def get_error_names_set(): +def get_error_names_set() -> set[str]: return _ERROR_NAMES diff --git a/pytype/file_utils.py b/pytype/file_utils.py index dc6c99a04..8c6d4cf5e 100644 --- a/pytype/file_utils.py +++ b/pytype/file_utils.py @@ -5,14 +5,17 @@ import os import re import sys +from typing import Optional, TypeVar from pytype.platform_utils import path_utils +_T0 = TypeVar("_T0") + PICKLE_EXT = ".pickled" -def recursive_glob(path): +def recursive_glob(path: _T0) -> list[_T0]: """Call recursive glob iff ** is in the pattern.""" if "*" not in path: # Glob isn't needed. @@ -32,7 +35,7 @@ def replace_extension(filename, new_extension): return name + "." 
+ new_extension -def makedirs(path): +def makedirs(path) -> None: """Create a nested directory, but don't fail if any of it already exists.""" try: os.makedirs(path) @@ -98,7 +101,7 @@ def expand_globpaths(globpaths, cwd=None): return expand_paths(paths, cwd) -def expand_source_files(filenames, cwd=None): +def expand_source_files(filenames, cwd=None) -> set: """Expand a space-separated string of filenames passed in as sources. This is a helper function for handling command line arguments that specify a @@ -138,7 +141,7 @@ def expand_pythonpath(pythonpath, cwd=None): return [] -def replace_separator(path: str): +def replace_separator(path: str) -> str: """replace `/` with `os.path.sep`, replace `:` with `os.pathsep`.""" if sys.platform == "win32": return path.replace("/", os.path.sep).replace(":", os.pathsep) @@ -146,7 +149,7 @@ def replace_separator(path: str): return path -def is_file_script(filename, directory=None): +def is_file_script(filename, directory=None) -> Optional[bool]: # This is for python files that do not have the .py extension # of course we assume that they start with a shebang file_path = expand_path(filename, directory) diff --git a/pytype/imports/base.py b/pytype/imports/base.py index bc71d3f74..e265968e6 100644 --- a/pytype/imports/base.py +++ b/pytype/imports/base.py @@ -14,10 +14,14 @@ import abc import dataclasses import os +from typing import TypeVar from pytype.pytd import pytd +_TModuleInfo = TypeVar("_TModuleInfo", bound="ModuleInfo") + + # Allow a file to be used as the designated default pyi for blacklisted files DEFAULT_PYI_PATH_SUFFIX = None @@ -26,7 +30,7 @@ PREFIX = "pytd:" -def internal_stub_filename(filename): +def internal_stub_filename(filename) -> str: """Filepath for pytype's internal pytd files.""" return PREFIX + filename @@ -40,10 +44,12 @@ class ModuleInfo: file_exists: bool = True @classmethod - def internal_stub(cls, module_name: str, filename: str): + def internal_stub( + cls: type[_TModuleInfo], module_name: str, 
filename: str + ) -> _TModuleInfo: return cls(module_name, internal_stub_filename(filename)) - def is_default_pyi(self): + def is_default_pyi(self) -> bool: return self.filename == os.devnull or ( DEFAULT_PYI_PATH_SUFFIX and self.filename.endswith(DEFAULT_PYI_PATH_SUFFIX) diff --git a/pytype/imports/builtin_stubs.py b/pytype/imports/builtin_stubs.py index 71fc6f9d5..64e87c2d6 100644 --- a/pytype/imports/builtin_stubs.py +++ b/pytype/imports/builtin_stubs.py @@ -1,22 +1,28 @@ """Utilities for parsing pytd files for builtins.""" +from typing import Any, TypeVar, Union + from pytype import pytype_source_utils from pytype.imports import base from pytype.platform_utils import path_utils from pytype.pyi import parser +from pytype.pytd import pytd from pytype.pytd import visitors + +_T1 = TypeVar("_T1") + # TODO(rechen): It would be nice to get rid of GetBuiltinsAndTyping, and let the # loader call BuiltinsAndTyping.load directly, but the cache currently prevents # slowdowns in tests that create loaders willy-nilly. Maybe load_pytd.py can # warn if there are more than n loaders in play, at any given time. -_cached_builtins_pytd = [] +_cached_builtins_pytd: list = [] # pylint: disable=invalid-name # We use a mix of camel case and snake case method names in this file. -def InvalidateCache(): +def InvalidateCache() -> None: if _cached_builtins_pytd: del _cached_builtins_pytd[0] @@ -36,19 +42,19 @@ def __getattr__(name: Any) -> Any: ... # If you have a Loader available, use loader.get_default_ast() instead. 
-def GetDefaultAst(options): +def GetDefaultAst(options) -> pytd.TypeDeclUnit: return parser.parse_string(src=DEFAULT_SRC, options=options) class BuiltinsAndTyping: """The builtins and typing modules, which need to be treated specially.""" - def _parse_predefined(self, name, options): + def _parse_predefined(self, name, options) -> pytd.TypeDeclUnit: _, src = GetPredefinedFile("builtins", name, ".pytd") mod = parser.parse_string(src, name=name, options=options) return mod - def load(self, options): + def load(self, options) -> tuple[Any, Any]: """Read builtins.pytd and typing.pytd, and return the parsed modules.""" t = self._parse_predefined("typing", options) b = self._parse_predefined("builtins", options) @@ -73,7 +79,7 @@ def load(self, options): def GetPredefinedFile( stubs_subdir, module, extension=".pytd", as_package=False -): +) -> tuple[str, Any]: """Get the contents of a predefined PyTD, typically with a file name *.pytd. Arguments: @@ -98,10 +104,12 @@ def GetPredefinedFile( class BuiltinLoader(base.BuiltinLoader): """Load builtins from the pytype source tree.""" - def __init__(self, options): + def __init__(self, options) -> None: self.options = options - def _parse_predefined(self, pytd_subdir, module, as_package=False): + def _parse_predefined( + self, pytd_subdir, module, as_package=False + ) -> pytd.TypeDeclUnit | None: """Parse a pyi/pytd file in the pytype source tree.""" try: filename, src = GetPredefinedFile( @@ -115,7 +123,9 @@ def _parse_predefined(self, pytd_subdir, module, as_package=False): assert ast.name == module return ast - def load_module(self, namespace, module_name): + def load_module( + self, namespace, module_name: _T1 + ) -> tuple[Union[str, _T1], Any]: """Load a stub that ships with pytype.""" mod = self._parse_predefined(namespace, module_name) # For stubs in pytype's stubs/ directory, we use the module name prefixed diff --git a/pytype/imports/module_loader.py b/pytype/imports/module_loader.py index 00a585d8a..25516a415 100644 
--- a/pytype/imports/module_loader.py +++ b/pytype/imports/module_loader.py @@ -10,7 +10,7 @@ from pytype.pyi import parser -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) class _PathFinder: @@ -123,7 +123,7 @@ def load_ast(self, mod_info: base.ModuleInfo): else: return self._load_pyi(mod_info) - def log_module_not_found(self, module_name: str): + def log_module_not_found(self, module_name: str) -> None: log.warning( "Couldn't import module %s %r in (path=%r) imports_map: %s", module_name, diff --git a/pytype/imports/pickle_utils.py b/pytype/imports/pickle_utils.py index d2b896dfe..11c3baaf0 100644 --- a/pytype/imports/pickle_utils.py +++ b/pytype/imports/pickle_utils.py @@ -35,6 +35,8 @@ import msgspec from pytype.pytd import pytd from pytype.pytd import serialize_ast +from pytype.pytd.serialize_ast import SerializableAst + Path = Union[str, os.PathLike[str]] @@ -48,14 +50,18 @@ def __init__(self, filename: Path): super().__init__(msg) -Encoder = msgspec.msgpack.Encoder(order="deterministic") -AstDecoder = msgspec.msgpack.Decoder(type=serialize_ast.SerializableAst) -BuiltinsDecoder = msgspec.msgpack.Decoder(type=serialize_ast.ModuleBundle) +Encoder: msgspec.msgpack.Encoder = msgspec.msgpack.Encoder(order="deterministic") +AstDecoder: msgspec.msgpack.Decoder[SerializableAst] = msgspec.msgpack.Decoder( +    type=serialize_ast.SerializableAst +) +BuiltinsDecoder: msgspec.msgpack.Decoder[ +    tuple[tuple[str, msgspec.Raw], ...]
+] = msgspec.msgpack.Decoder(type=serialize_ast.ModuleBundle) _DecT = TypeVar( "_DecT", serialize_ast.SerializableAst, serialize_ast.ModuleBundle ) -_Dec = msgspec.msgpack.Decoder +_Dec: type[msgspec.msgpack.Decoder] = msgspec.msgpack.Decoder _Serializable = Union[serialize_ast.SerializableAst, serialize_ast.ModuleBundle] diff --git a/pytype/imports/typeshed.py b/pytype/imports/typeshed.py index 9c52216aa..39b71e281 100644 --- a/pytype/imports/typeshed.py +++ b/pytype/imports/typeshed.py @@ -2,9 +2,10 @@ import abc import collections -from collections.abc import Collection, Sequence +from collections.abc import Collection, Generator, Sequence import os import re +from typing import Any from pytype import module_utils from pytype import pytype_source_utils @@ -15,7 +16,7 @@ from pytype.pyi import parser -def _get_module_names_in_path(lister, path, python_version): +def _get_module_names_in_path(lister, path, python_version) -> set: """Get module names for all .pyi files in the given path.""" names = set() try: @@ -66,7 +67,7 @@ def load_file(self, relpath) -> tuple[str, str]: class TypeshedFs(TypeshedStore): """Filesystem-based typeshed store.""" - def __init__(self, *, missing_file=None, open_function=open): + def __init__(self, *, missing_file=None, open_function=open) -> None: self._root = self.get_root() self._open_function = open_function self._missing_file = missing_file @@ -83,7 +84,7 @@ def load_file(self, relpath) -> tuple[str, str]: with self._open_function(filename) as f: return relpath, f.read() - def _readlines(self, unix_relpath): + def _readlines(self, unix_relpath) -> list[str]: relpath = path_utils.join(*unix_relpath.split("/")) _, data = self.load_file(relpath) return data.splitlines() @@ -108,16 +109,16 @@ class InternalTypeshedFs(TypeshedFs): def get_root(self): return pytype_source_utils.get_full_path("typeshed") - def _list_files(self, relpath): + def _list_files(self, relpath) -> Generator[Any, Any, None]: """Lists files recursively in a 
basedir relative to typeshed root.""" return pytype_source_utils.list_pytype_files( path_utils.join("typeshed", relpath) ) - def list_files(self, relpath): + def list_files(self, relpath) -> list: return list(self._list_files(relpath)) - def file_exists(self, relpath): + def file_exists(self, relpath) -> bool: try: # For a non-par pytype installation, load_text_file will either succeed, # raise FileNotFoundError, or raise IsADirectoryError. @@ -145,7 +146,7 @@ def load_file(self, relpath) -> tuple[str, str]: class ExternalTypeshedFs(TypeshedFs): """Typeshed installation pointed to by TYPESHED_HOME.""" - def get_root(self): + def get_root(self) -> str: home = os.getenv("TYPESHED_HOME") if not home or not path_utils.isdir(home): raise OSError( @@ -154,14 +155,14 @@ def get_root(self): ) return home - def _list_files(self, relpath): + def _list_files(self, relpath) -> Generator[str, Any, None]: """Lists files recursively in a basedir relative to typeshed root.""" return pytype_source_utils.list_files(self.filepath(relpath)) - def list_files(self, relpath): + def list_files(self, relpath) -> list: return list(self._list_files(relpath)) - def file_exists(self, relpath): + def file_exists(self, relpath) -> bool: return path_utils.exists(self.filepath(relpath)) @@ -180,7 +181,7 @@ class Typeshed: # For testing, this file must contain the entry 'stdlib/pytypecanary'. MISSING_FILE = None - def __init__(self, missing_modules: Collection[str] = ()): + def __init__(self, missing_modules: Collection[str] = ()) -> None: """Initializer. 
Args: @@ -197,11 +198,13 @@ def __init__(self, missing_modules: Collection[str] = ()): self._stdlib_versions = self._load_stdlib_versions() self._third_party_packages = self._load_third_party_packages() - def _load_missing(self): + def _load_missing(self) -> frozenset: lines = self._store.load_missing() return frozenset(line.strip() for line in lines if line) - def _load_stdlib_versions(self): + def _load_stdlib_versions( + self, + ) -> dict[Any, tuple[tuple[int, int], tuple[int, int] | None]]: """Loads the contents of typeshed/stdlib/VERSIONS. VERSIONS lists the stdlib modules with the Python version in which they were @@ -231,7 +234,7 @@ def _load_stdlib_versions(self): versions[module] = minimum, maximum return versions - def _load_third_party_packages(self): + def _load_third_party_packages(self) -> collections.defaultdict: """Loads package and Python version information for typeshed/stubs/. stubs/ contains type information for third-party packages. Each top-level @@ -265,7 +268,7 @@ def missing(self): """Set of known-missing typeshed modules, as strings of paths.""" return self._missing - def get_module_file(self, namespace, module, version): + def get_module_file(self, namespace, module, version) -> tuple[str, str]: """Get the contents of a typeshed .pyi file. 
Arguments: @@ -353,7 +356,7 @@ def get_pytd_paths(self): for d in (f"stubs{os.path.sep}builtins", f"stubs{os.path.sep}stdlib") ] - def _list_modules(self, path, python_version): + def _list_modules(self, path, python_version) -> Generator[Any, Any, None]: """Lists modules for _get_module_names_in_path.""" for filename in self._store.list_files(path): if filename in ("VERSIONS", "METADATA.toml"): @@ -369,7 +372,7 @@ def _list_modules(self, path, python_version): continue yield filename - def _get_missing_modules(self): + def _get_missing_modules(self) -> set[str]: """Gets module names from the `missing` list.""" module_names = set() for f in self.missing: @@ -383,7 +386,7 @@ def _get_missing_modules(self): module_names.add(filename.replace(os.path.sep, ".")) return module_names - def get_all_module_names(self, python_version): + def get_all_module_names(self, python_version) -> set: """Get the names of all modules in typeshed or bundled with pytype.""" module_names = set() for abspath in self.get_typeshed_paths(): @@ -403,7 +406,7 @@ def get_all_module_names(self, python_version): assert "ctypes" in module_names # sanity check return module_names - def read_blacklist(self): + def read_blacklist(self) -> Generator[str, Any, None]: """Read the typeshed blacklist.""" lines = self._store.load_pytype_blocklist() for line in lines: @@ -413,7 +416,7 @@ def read_blacklist(self): if line: yield line - def blacklisted_modules(self): + def blacklisted_modules(self) -> Generator[Any, Any, None]: """Return the blacklist, as a list of module names. E.g. ["x", "y.z"].""" for path in self.read_blacklist(): # E.g. 
["stdlib", "html", "parser.pyi"] @@ -427,7 +430,7 @@ def blacklisted_modules(self): yield mod -def _get_typeshed(missing_modules): +def _get_typeshed(missing_modules) -> Typeshed: """Get a Typeshed instance.""" try: return Typeshed(missing_modules) @@ -441,12 +444,14 @@ def _get_typeshed(missing_modules): class TypeshedLoader(base.BuiltinLoader): """Load modules from typeshed.""" - def __init__(self, options, missing_modules): + def __init__(self, options, missing_modules) -> None: self.options = options self.typeshed = _get_typeshed(missing_modules) # TODO(mdemello): Inject options.open_function into self.typeshed - def load_module(self, namespace, module_name): + def load_module( + self, namespace, module_name + ): """Load and parse a *.pyi from typeshed. Args: diff --git a/pytype/imports_map.py b/pytype/imports_map.py index 47c9bcc02..577b58125 100644 --- a/pytype/imports_map.py +++ b/pytype/imports_map.py @@ -17,11 +17,11 @@ class ImportsMap: items: Mapping[str, str] = dataclasses.field(default_factory=dict) unused: Sequence[str] = dataclasses.field(default_factory=list) - def __getitem__(self, key: str): + def __getitem__(self, key: str) -> str: return self.items[key] - def __contains__(self, key: str): + def __contains__(self, key: str) -> bool: return key in self.items - def __len__(self): + def __len__(self) -> int: return len(self.items) diff --git a/pytype/imports_map_loader.py b/pytype/imports_map_loader.py index 1c9f23ae9..04cbeff21 100644 --- a/pytype/imports_map_loader.py +++ b/pytype/imports_map_loader.py @@ -7,7 +7,7 @@ from pytype import imports_map from pytype.platform_utils import path_utils -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) # Type aliases. 
@@ -18,7 +18,7 @@ class ImportsMapBuilder: """Build an imports map from (short_path, path) pairs.""" - def __init__(self, options): + def __init__(self, options) -> None: self.options = options def _read_from_file(self, path) -> list[_ItemType]: diff --git a/pytype/inspect/graph.py b/pytype/inspect/graph.py index 88d0b5c59..f04c91d54 100644 --- a/pytype/inspect/graph.py +++ b/pytype/inspect/graph.py @@ -6,7 +6,7 @@ import networkx as nx -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) def obj_key(n): @@ -14,29 +14,29 @@ def obj_key(n): return n.__class__.__name__ + str(nid) -def obj_repr(n): +def obj_repr(n) -> str: return repr(n.data)[:10] class TypeGraph: """Networkx graph builder.""" - def __init__(self, program, ignored, only_cfg=False): + def __init__(self, program, ignored, only_cfg=False) -> None: self.graph = nx.MultiDiGraph() self._add_cfg(program, ignored) if not only_cfg: self._add_variables(program, ignored) - def add_node(self, obj, **kwargs): + def add_node(self, obj, **kwargs) -> None: self.graph.add_node(obj_key(obj), **kwargs) - def add_edge(self, obj1, obj2, **kwargs): + def add_edge(self, obj1, obj2, **kwargs) -> None: self.graph.add_edge(obj_key(obj1), obj_key(obj2), **kwargs) def to_dot(self): return nx.nx_pydot.to_pydot(self.graph).to_string() - def _add_cfg(self, program, ignored): + def _add_cfg(self, program, ignored) -> None: """Add program cfg nodes.""" for node in program.cfg_nodes: @@ -48,7 +48,7 @@ def _add_cfg(self, program, ignored): for other in node.outgoing: self.add_edge(node, other, penwidth=2.0) - def _add_variables(self, program, ignored): + def _add_variables(self, program, ignored) -> None: """A dd program variables and bindings.""" def _is_constant(val): @@ -86,7 +86,7 @@ def _is_constant(val): self.add_edge(src, srcs, color="lightblue", weight=2) -def write_svg_from_dot(svg_file, dot): +def write_svg_from_dot(svg_file, dot) -> None: with subprocess.Popen( ["/usr/bin/dot", "-T", "svg", 
"-o", svg_file], stdin=subprocess.PIPE, diff --git a/pytype/io.py b/pytype/io.py index 47440c430..2c1303a1e 100644 --- a/pytype/io.py +++ b/pytype/io.py @@ -6,6 +6,7 @@ import os import sys import traceback +from typing import Any, Callable import libcst from pytype import __version__ @@ -28,7 +29,7 @@ from pytype.rewrite import analyze as rewrite_analyze -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) # Webpage explaining the pytype error codes @@ -52,7 +53,7 @@ def read_source_file(input_filename, open_function=open): raise utils.UsageError(f"Could not load input file {input_filename}") from e -def _set_verbosity_from(posarg): +def _set_verbosity_from(posarg) -> Callable[[Any], Any]: """Decorator to set the verbosity for a function that takes an options arg. Assumes that the function has an argument named `options` that is a @@ -161,7 +162,9 @@ def _output_ast( return result -def generate_pyi(src, options=None, loader=None): +def generate_pyi( + src, options=None, loader=None +) -> tuple[analyze.Analysis, str]: """Run the inferencer on a string of source code, producing output. 
Args: @@ -248,7 +251,7 @@ def check_or_generate_pyi(options) -> AnalysisResult: return AnalysisResult(ctx, ast, result) -def _write_pyi_output(options, contents, filename): +def _write_pyi_output(options, contents, filename) -> None: assert filename if filename == "-": sys.stdout.write(contents) @@ -348,7 +351,7 @@ def write_pickle(ast, options, loader=None): ) -def print_error_doc_url(errorlog): +def print_error_doc_url(errorlog) -> None: names = {e.name for e in errorlog} if names: doclink = f"\nFor more details, see {ERROR_DOC_URL}" @@ -388,12 +391,14 @@ def parse_pyi(options): return ast -def get_pytype_version(): +def get_pytype_version() -> str: return __version__.__version__ @contextlib.contextmanager -def wrap_pytype_exceptions(exception_type, filename=""): +def wrap_pytype_exceptions( + exception_type, filename="" +): """Catch pytype errors and reraise them as a single exception type. NOTE: This will also wrap non-pytype errors thrown within the body of the diff --git a/pytype/load_pytd.py b/pytype/load_pytd.py index b314e8d01..4f791f8fe 100644 --- a/pytype/load_pytd.py +++ b/pytype/load_pytd.py @@ -6,6 +6,7 @@ import functools import logging import os +from typing import Any, TypeVar from pytype import file_utils from pytype import module_utils @@ -20,15 +21,26 @@ from pytype.pytd import pytd_utils from pytype.pytd import serialize_ast from pytype.pytd import visitors +from pytype.pytd.pytd import TypeDeclUnit -log = logging.getLogger(__name__) + +_ModuleNameType: type[str] +_AliasNameType: type[str] +_NameType: type[str] + +_T0 = TypeVar("_T0") +_T1 = TypeVar("_T1") +_TModule = TypeVar("_TModule", bound="Module") +_TPickledPyiLoader = TypeVar("_TPickledPyiLoader", bound="PickledPyiLoader") + +log: logging.Logger = logging.getLogger(__name__) # Always load this module from typeshed, even if we have it in the imports map -_ALWAYS_PREFER_TYPESHED = frozenset({"typing_extensions"}) +_ALWAYS_PREFER_TYPESHED: frozenset[str] = frozenset({"typing_extensions"}) # 
Type alias -_AST = pytd.TypeDeclUnit -ModuleInfo = imports_base.ModuleInfo +_AST: type[TypeDeclUnit] = pytd.TypeDeclUnit +ModuleInfo: type[imports_base.ModuleInfo] = imports_base.ModuleInfo def create_loader(options, missing_modules=()): @@ -101,7 +113,7 @@ def __init__( metadata=None, pickle=None, has_unresolved_pointers=True, - ): + ) -> None: self.module_name = module_name self.filename = filename self.ast = ast @@ -111,14 +123,14 @@ def __init__( # pylint: enable=redefined-outer-name - def needs_unpickling(self): + def needs_unpickling(self) -> bool: return bool(self.pickle) def is_package(self): return _is_package(self.filename) @classmethod - def resolved_internal_stub(cls, name, mod_ast): + def resolved_internal_stub(cls: type[_TModule], name, mod_ast) -> _TModule: return cls( name, imports_base.internal_stub_filename(name), @@ -130,18 +142,18 @@ def resolved_internal_stub(cls, name, mod_ast): class BadDependencyError(Exception): """If we can't resolve a module referenced by the one we're trying to load.""" - def __init__(self, module_error, src=None): + def __init__(self, module_error, src=None) -> None: referenced = f", referenced from {src!r}" if src else "" super().__init__(module_error + referenced) - def __str__(self): + def __str__(self) -> str: return str(self.args[0]) class _ModuleMap: """A map of fully qualified module name -> Module.""" - def __init__(self, options, modules): + def __init__(self, options, modules) -> None: self.options = options self._modules: dict[str, Module] = modules or self._base_modules() if self._modules["builtins"].needs_unpickling(): @@ -150,16 +162,16 @@ def __init__(self, options, modules): self._unpickle_module(self._modules["typing"]) self._concatenated = None - def __getitem__(self, key): + def __getitem__(self, key) -> Module: return self._modules[key] - def __setitem__(self, key, val): + def __setitem__(self, key, val) -> None: self._modules[key] = val - def __delitem__(self, key): + def __delitem__(self, key) -> 
None: del self._modules[key] - def __contains__(self, key): + def __contains__(self, key) -> bool: return key in self._modules def items(self): @@ -199,7 +211,7 @@ def get_resolved_modules(self) -> dict[str, ResolvedModule]: ) return resolved_modules - def _base_modules(self): + def _base_modules(self) -> dict[str, Any]: bltins, typing = builtin_stubs.GetBuiltinsAndTyping( parser.PyiOptions.from_toplevel_options(self.options) ) @@ -208,7 +220,7 @@ def _base_modules(self): "typing": Module.resolved_internal_stub("typing", typing), } - def _unpickle_module(self, module): + def _unpickle_module(self, module) -> None: """Unpickle a pickled ast and its dependencies.""" if not module.pickle: return @@ -257,14 +269,14 @@ def concat_all(self): self._concatenated = pytd_utils.Concat(*self.defined_asts(), name="") return self._concatenated - def invalidate_concatenated(self): + def invalidate_concatenated(self) -> None: self._concatenated = None class _Resolver: """Resolve symbols in a pytd tree.""" - def __init__(self, builtins_ast): + def __init__(self, builtins_ast) -> None: self.builtins_ast = builtins_ast self.allow_singletons = False @@ -316,8 +328,8 @@ def resolve_external_types(self, mod_ast, module_map, aliases, *, mod_name): raise BadDependencyError(key, name) from key_error def resolve_module_alias( - self, name, *, lookup_ast=None, lookup_ast_name=None - ): + self, name: _T0, *, lookup_ast=None, lookup_ast_name=None + ) -> _T0: """Check if a given name is an alias and resolve it if so.""" # name is bare, but aliases are stored as "ast_name.alias". 
if lookup_ast is None: @@ -329,11 +341,11 @@ def resolve_module_alias( key = f"{ast_name}.{cur_name}" value = aliases.get(key) if isinstance(value, pytd.Module): - return value.module_name + name[len(cur_name) :] + return value.module_name + name[len(cur_name) :] # pytype: disable=attribute-error cur_name, _, _ = cur_name.rpartition(".") return name - def verify(self, mod_ast, *, mod_name=None): + def verify(self, mod_ast, *, mod_name=None) -> None: try: mod_ast.Visit(visitors.VerifyLookup(ignore_late_types=True)) except ValueError as e: @@ -414,7 +426,7 @@ class Loader: typing: The typing ast. """ - def __init__(self, options, modules=None, missing_modules=()): + def __init__(self, options, modules=None, missing_modules=()) -> None: self.options = options self._modules = _ModuleMap(options, modules) self.builtins = self._modules["builtins"].ast @@ -442,12 +454,12 @@ def _builtin_loader(self): def _pyi_options(self): return parser.PyiOptions.from_toplevel_options(self.options) - def get_default_ast(self): + def get_default_ast(self) -> TypeDeclUnit: return builtin_stubs.GetDefaultAst( parser.PyiOptions.from_toplevel_options(self.options) ) - def save_to_pickle(self, filename): + def save_to_pickle(self, filename) -> None: """Save to a pickle. See PickledPyiLoader.load_from_pickle for reverse.""" # We assume that the Loader is in a consistent state here. 
In particular, we # assume that for every module in _modules, all the transitive dependencies @@ -554,7 +566,7 @@ def _try_import_prefix(self, name: str) -> _AST | None: def _load_ast_dependencies( self, dependencies, lookup_ast, lookup_ast_name=None - ): + ) -> None: """Fill in all ClassType.cls pointers and load reexported modules.""" ast_name = lookup_ast_name or lookup_ast.name for dep_name in dependencies: @@ -629,7 +641,7 @@ def _resolve_external_types(self, mod_ast, lookup_ast=None): ) return mod_ast - def _resolve_classtype_pointers(self, mod_ast, *, lookup_ast=None): + def _resolve_classtype_pointers(self, mod_ast, *, lookup_ast=None) -> None: module_map = self._modules.get_module_map() module_map[""] = lookup_ast or mod_ast # The module itself (local lookup) mod_ast.Visit(visitors.FillInLocalPointers(module_map)) @@ -653,7 +665,7 @@ def resolve_ast(self, ast): # NOTE: Modules of dependencies will be loaded into the cache return self.resolve_pytd(ast, ast) - def _resolve_classtype_pointers_for_all_modules(self): + def _resolve_classtype_pointers_for_all_modules(self) -> None: for module in self._modules.values(): if module.has_unresolved_pointers: self._resolve_classtype_pointers(module.ast) @@ -745,11 +757,11 @@ def finish_and_verify_ast(self, mod_ast): self._resolver.verify(mod_ast) return mod_ast - def add_module_prefixes(self, module_name): + def add_module_prefixes(self, module_name) -> None: for prefix in module_utils.get_all_prefixes(module_name): self._prefixes.add(prefix) - def has_module_prefix(self, prefix): + def has_module_prefix(self, prefix) -> bool: return prefix in self._prefixes def _load_builtin(self, namespace, module_name): @@ -835,7 +847,7 @@ def _import_module_by_name(self, module_name) -> _AST | None: def concat_all(self): return self._modules.concat_all() - def get_resolved_modules(self): + def get_resolved_modules(self) -> dict[str, ResolvedModule]: """Gets a name -> ResolvedModule map of the loader's resolved modules.""" return 
self._modules.get_resolved_modules() @@ -855,7 +867,9 @@ class PickledPyiLoader(Loader): """A Loader which always loads pickle instead of PYI, for speed.""" @classmethod - def load_from_pickle(cls, filename, options, missing_modules=()): + def load_from_pickle( + cls: type[_TPickledPyiLoader], filename, options, missing_modules=() + ) -> _TPickledPyiLoader: """Load a pytd module from a pickle file.""" items = pickle_utils.LoadBuiltins( filename, compress=True, open_function=options.open_function diff --git a/pytype/main.py b/pytype/main.py index fefcd9a12..da9ef3caf 100755 --- a/pytype/main.py +++ b/pytype/main.py @@ -19,13 +19,13 @@ from pytype.imports import typeshed -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) class _ProfileContext: """A context manager for optionally profiling code.""" - def __init__(self, output_path): + def __init__(self, output_path) -> None: """Initialize. Args: @@ -35,17 +35,17 @@ def __init__(self, output_path): self._output_path = output_path self._profile = cProfile.Profile() if self._output_path else None - def __enter__(self): + def __enter__(self) -> None: if self._profile: self._profile.enable() - def __exit__(self, exc_type, exc_value, traceback): # pylint: disable=redefined-outer-name + def __exit__(self, exc_type, exc_value, traceback) -> None: # pylint: disable=redefined-outer-name if self._profile: self._profile.disable() self._profile.dump_stats(self._output_path) -def _generate_builtins_pickle(options): +def _generate_builtins_pickle(options) -> None: """Create a pickled file with the standard library (typeshed + builtins).""" loader = load_pytd.create_loader(options) t = typeshed.Typeshed() @@ -57,7 +57,7 @@ def _generate_builtins_pickle(options): loader.save_to_pickle(options.generate_builtins) -def _expand_args(argv): +def _expand_args(argv) -> list: """Returns argv with flagfiles expanded. A flagfile is an argument starting with "@". 
The remainder of the argument is @@ -82,7 +82,7 @@ def _expand_single_arg(arg, result): return expanded_args -def _fix_spaces(argv): +def _fix_spaces(argv) -> list: """Returns argv with unescaped spaces in paths fixed. This is needed for the analyze_project tool, which uses ninja to run pytype diff --git a/pytype/matcher.py b/pytype/matcher.py index 66508a95c..5497dde6a 100644 --- a/pytype/matcher.py +++ b/pytype/matcher.py @@ -1,11 +1,10 @@ """Matching logic for abstract values.""" import collections -from collections.abc import Iterable import contextlib import dataclasses import logging -from typing import Any, cast +from typing import Callable, Generator, Optional, TypeVar, Any, cast from pytype import datatypes from pytype import utils @@ -23,8 +22,11 @@ from pytype.typegraph import cfg from pytype.types import types +_T3 = TypeVar("_T3") +_TGoodMatch = TypeVar("_TGoodMatch", bound="GoodMatch") -log = logging.getLogger(__name__) + +log: logging.Logger = logging.getLogger(__name__) _SubstType = datatypes.AliasingDict[str, cfg.Variable] _ViewType = datatypes.AccessTrackingDict[cfg.Variable, cfg.Binding] @@ -42,7 +44,7 @@ def _is_callback_protocol(typ): ) -def _compute_superset_info(subst_key1, subst_key2): +def _compute_superset_info(subst_key1, subst_key2) -> tuple[bool, bool]: """Compute whether subst_key1 is a superset of subst_key2 and vice versa.""" # Since repeatedly iterating over subst keys is slow, we do both computations # in one loop. 
@@ -76,7 +78,7 @@ class GoodMatch: subst: _SubstType @classmethod - def default(cls): + def default(cls: type[_TGoodMatch]) -> _TGoodMatch: return cls(datatypes.AccessTrackingDict(), datatypes.HashableDict()) @@ -109,14 +111,14 @@ class MatchResult: class _UniqueMatches: """A collection of matches that discards duplicates.""" - def __init__(self, node, keep_all_views): + def __init__(self, node, keep_all_views) -> None: self._node = node self._keep_all_views = keep_all_views self._data: dict[ _ViewKeyType, list[tuple[_SubstKeyType, _ViewType, _SubstType]] ] = collections.defaultdict(list) - def insert(self, view, subst): + def insert(self, view, subst) -> None: """Insert a subst with associated data.""" if self._keep_all_views: view_key = tuple( @@ -154,7 +156,7 @@ def insert(self, view, subst): else: self._data[view_key].append(data_item) - def unique(self) -> Iterable[tuple[_ViewType, _SubstType]]: + def unique(self): for values in self._data.values(): for _, view, subst in values: yield (view, subst) @@ -163,11 +165,11 @@ def unique(self) -> Iterable[tuple[_ViewType, _SubstType]]: class _TypeParams: """Collection of TypeParameter objects encountered during matching.""" - def __init__(self): + def __init__(self) -> None: self.seen = set() self._mutually_exclusive = collections.defaultdict(set) - def add_mutually_exclusive_groups(self, groups): + def add_mutually_exclusive_groups(self, groups) -> None: """Adds groups of mutually exclusive type parameters. 
For example, [{"T1", "T2"}, {"T3", "T4"}] would mean that the following @@ -182,7 +184,7 @@ def add_mutually_exclusive_groups(self, groups): for name in group: self._mutually_exclusive[name].update(mutually_exclusive) - def has_mutually_exclusive(self, name, subst): + def has_mutually_exclusive(self, name, subst) -> bool: """Whether 'subst' has a param that is mutually exclusive with 'name'.""" return bool(self._mutually_exclusive[name].intersection(subst)) @@ -190,7 +192,7 @@ def has_mutually_exclusive(self, name, subst): class AbstractMatcher(utils.ContextWeakrefMixin): """Matcher for abstract values.""" - def __init__(self, node, ctx): + def __init__(self, node, ctx) -> None: super().__init__(ctx) self._node = node self._protocol_cache = set() @@ -213,13 +215,15 @@ def __init__(self, node, ctx): self._reset_errors() - def _reset_errors(self): + def _reset_errors(self) -> None: self._protocol_error = None self._noniterable_str_error = None self._typed_dict_error = None @contextlib.contextmanager - def _track_partially_matched_protocols(self): + def _track_partially_matched_protocols( + self, + ): """Context manager for handling the protocol cache. Some protocols have methods that return instances of the protocol, e.g. 
@@ -234,7 +238,7 @@ def _track_partially_matched_protocols(self): yield self._protocol_cache = old_protocol_cache - def _error_details(self): + def _error_details(self) -> error_types.MatcherErrorDetails: """Package up additional error details.""" return error_types.MatcherErrorDetails( protocol=self._protocol_error, @@ -353,7 +357,7 @@ def compute_one_match( if subst is None: if self._node.CanHaveCombination(list(view.values())): bad_matches.append( - BadMatch( + BadMatch( # pytype: disable=wrong-arg-types view=view, expected=self._get_bad_type(name, other_type), actual=var, @@ -428,7 +432,7 @@ def match_var_against_type(self, var, other_type, subst, view): """Match a variable against a type.""" self._reset_errors() if var.bindings: - return self._match_value_against_type(view[var], other_type, subst, view) + return self._match_value_against_type(view[var], other_type, subst, view) # pytype: disable=wrong-arg-types else: # Empty set of values. The "nothing" type. if isinstance(other_type, abstract.TupleClass): other_type = other_type.get_formal_type_parameter(abstract_utils.T) @@ -978,7 +982,7 @@ def _mutate_type_parameters(self, params, value, subst): new_subst = {p.full_name: value.to_variable(self._node) for p in params} return self._merge_substs(subst, [new_subst]) - def _get_param_matcher(self, callable_type): + def _get_param_matcher(self, callable_type) -> Callable[[Any, Any, Any], Any]: """Helper for matching the parameters of a callable. 
Args: @@ -1171,7 +1175,7 @@ def _merge_matches( def _match_subst_against_subst( self, old_subst, new_subst, type_param_map, has_self - ): + ) -> Optional[datatypes.AliasingDict]: subst = datatypes.AliasingDict(aliases=old_subst.aliases) for t in new_subst: if t not in old_subst or not old_subst[t].bindings: @@ -1405,7 +1409,7 @@ def assert_classes_match(cls1, cls2): def _match_instance_param_against_class_param( self, instance_param, class_param, subst, view - ): + ) -> tuple[Any, Any]: if instance_param.bindings and instance_param not in view: binding = instance_param.bindings[0] view = view.copy() @@ -1444,8 +1448,8 @@ def _match_instance_parameters(self, left, instance, other_type, subst, view): return subst def _match_fiddle_instance_against_bare_type( - self, left, instance, other_type, subst, view - ): + self, left, instance, other_type, subst: _T3, view + ) -> Optional[_T3]: """Match a fiddle instance against an unsubscripted buildable pytd type.""" assert isinstance(instance, fiddle_overlay.Buildable) assert isinstance(other_type, abstract.PyTDClass) @@ -1654,7 +1658,7 @@ def _match_typed_dict_against_dict( ) return subst - def _get_attribute_names(self, left): + def _get_attribute_names(self, left) -> set[str]: """Get the attributes implemented (or implicit) on a type.""" left_attributes = set() if isinstance(left, abstract.Module): @@ -1775,7 +1779,7 @@ def _resolve_function_attribute_var( else: return attribute, False - def _is_native_callable(self, val): + def _is_native_callable(self, val) -> bool: return isinstance(val, abstract.NativeFunction) and isinstance( val.func.__self__, abstract.CallableClass ) @@ -1784,7 +1788,7 @@ def _resolve_function_attribute_value( self, attr: abstract.BaseValue, unbind: bool ) -> abstract.BaseValue: if self._is_native_callable(attr): - sig = function.Signature.from_callable(attr.func.__self__) + sig = function.Signature.from_callable(attr.func.__self__) # pytype: disable=attribute-error if unbind: sig = 
sig.prepend_parameter("self", self.ctx.convert.unsolvable) return abstract.SimpleFunction(sig, self.ctx) @@ -1814,7 +1818,9 @@ def _get_type(self, value): return abstract.ParameterizedClass(cls, parameters, self.ctx) return cls - def _get_attribute_types(self, other_type, attribute): + def _get_attribute_types( + self, other_type, attribute + ) -> Generator[Any, Any, None]: if not abstract_utils.is_callable(attribute): typ = self._get_type(attribute) if typ: @@ -1939,7 +1945,7 @@ def _match_protocol_attribute(self, left, other_type, attribute, subst, view): return None return self._merge_substs(subst, new_substs) - def _discard_ambiguous_values(self, values): + def _discard_ambiguous_values(self, values) -> list: # TODO(rechen): For type parameter instances, we should extract the concrete # value from v.instance so that we can check it, rather than ignoring the # value altogether. @@ -1959,7 +1965,7 @@ def _discard_ambiguous_values(self, values): concrete_values.append(v) return concrete_values - def _satisfies_single_type(self, values): + def _satisfies_single_type(self, values) -> bool: """Enforce that the variable contains only one concrete type.""" class_names = {v.cls.full_name for v in values} for compat_name, name in self._compatible_builtins: @@ -1968,7 +1974,7 @@ def _satisfies_single_type(self, values): # We require all occurrences to be of the same type, no subtyping allowed. 
return len(class_names) <= 1 - def _satisfies_common_superclass(self, values): + def _satisfies_common_superclass(self, values) -> bool: """Enforce that the variable's values share a superclass below object.""" common_classes = None object_in_values = False @@ -1996,7 +2002,7 @@ def _satisfies_common_superclass(self, values): return False return True - def _satisfies_noniterable_str(self, left, other_type): + def _satisfies_noniterable_str(self, left, other_type) -> bool: """Enforce a str to NOT be matched against a conflicting iterable type.""" conflicting_iter_types = [ "typing.Iterable", diff --git a/pytype/metrics.py b/pytype/metrics.py index 498ba7c7c..794a77b6b 100644 --- a/pytype/metrics.py +++ b/pytype/metrics.py @@ -24,6 +24,7 @@ def bar(n): import re import time import types +from typing import Any, Optional try: import tracemalloc # pylint: disable=g-import-not-at-top @@ -37,16 +38,16 @@ def bar(n): # need to write custom JsonEncoder and JsonDecoder classes per Metric subclass. # Register metric types for deserialization. -_METRIC_TYPES = {} +_METRIC_TYPES: dict = {} # Map from metric name to Metric object. -_registered_metrics = {} +_registered_metrics: dict = {} # Whether metrics should be collected. 
-_enabled = False +_enabled: Any = False -def reset(): +def reset() -> None: """Resets this module to its initial state.""" _METRIC_TYPES.clear() _registered_metrics.clear() @@ -73,12 +74,12 @@ def _deserialize(typ, payload): return out -def _serialize(obj): +def _serialize(obj) -> list: """Return a json-serializable form of object.""" return [obj.__class__.__name__, vars(obj)] -def dump_all(objs, fp): +def dump_all(objs, fp) -> None: """Write a list of metrics to a json file.""" json.dump([_serialize(x) for x in objs], fp) @@ -89,15 +90,15 @@ def load_all(fp): return [_deserialize(*x) for x in metrics] -_METRIC_NAME_RE = re.compile(r"^[a-zA-Z_]\w+$") +_METRIC_NAME_RE: re.Pattern = re.compile(r"^[a-zA-Z_]\w+$") -def _validate_metric_name(name): +def _validate_metric_name(name) -> None: if _METRIC_NAME_RE.match(name) is None: raise ValueError(f"Illegal metric name: {name}") -def _prepare_for_test(enabled=True): +def _prepare_for_test(enabled=True) -> None: """Setup metrics collection for a test.""" _registered_metrics.clear() global _enabled @@ -107,7 +108,7 @@ def _prepare_for_test(enabled=True): _platform_timer = time.time if os.name == "nt" else time.process_time -def get_cpu_clock(): +def get_cpu_clock() -> float: """Returns CPU clock to keep compatibility with various Python versions.""" return _platform_timer() @@ -132,7 +133,7 @@ def get_metric(name, constructor, *args, **kwargs): return constructor(name, *args, **kwargs) -def get_report(): +def get_report() -> str: """Return a string listing all metrics, one per line.""" lines = [ str(_registered_metrics[n]) + "\n" for n in sorted(_registered_metrics) @@ -140,7 +141,7 @@ def get_report(): return "".join(lines) -def merge_from_file(metrics_file): +def merge_from_file(metrics_file) -> None: """Merge metrics recorded in another file into the current metrics.""" for metric in load_all(metrics_file): existing = _registered_metrics.get(metric.name) @@ -156,7 +157,7 @@ def merge_from_file(metrics_file): class 
Metric(metaclass=_RegistryMeta): """Abstract base class for metrics.""" - def __init__(self, name): + def __init__(self, name) -> None: """Initialize the metric and register it under the specified name.""" if name is None: # We do not want to register this metric (e.g. we are deserializing a @@ -181,18 +182,18 @@ def _merge(self, other): """Merge data from another metric of the same type.""" raise NotImplementedError - def __str__(self): + def __str__(self) -> str: return f"{self._name}: {self._summary()}" class Counter(Metric): """A monotonically increasing metric.""" - def __init__(self, name): + def __init__(self, name) -> None: super().__init__(name) self._total = 0 - def inc(self, count=1): + def inc(self, count=1) -> None: """Increment the metric by the specified amount.""" if count < 0: raise ValueError("Counter must be monotonically increasing.") @@ -200,10 +201,10 @@ def inc(self, count=1): return self._total += count - def _summary(self): + def _summary(self) -> str: return str(self._total) - def _merge(self, other): + def _merge(self, other) -> None: # pylint: disable=protected-access self._total += other._total @@ -211,17 +212,17 @@ def _merge(self, other): class StopWatch(Metric): """A counter that measures the time spent in a "with" statement.""" - def __enter__(self): + def __enter__(self) -> None: self._start_time = get_cpu_clock() - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None: self._total = get_cpu_clock() - self._start_time del self._start_time - def _summary(self): + def _summary(self) -> str: return f"{self._total:f} seconds" - def _merge(self, other): + def _merge(self, other) -> None: # pylint: disable=protected-access self._total += other._total @@ -229,38 +230,38 @@ def _merge(self, other): class ReentrantStopWatch(Metric): """A watch that supports being called multiple times and recursively.""" - def __init__(self, name): + def __init__(self, name) -> 
None: super().__init__(name) self._time = 0 self._calls = 0 - def __enter__(self): + def __enter__(self) -> None: if not self._calls: self._start_time = get_cpu_clock() self._calls += 1 - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None: self._calls -= 1 if not self._calls: self._time += get_cpu_clock() - self._start_time del self._start_time - def _merge(self, other): + def _merge(self, other) -> None: self._time += other._time # pylint: disable=protected-access - def _summary(self): + def _summary(self) -> str: return f"time spend below this StopWatch: {self._time}" class MapCounter(Metric): """A set of related counters keyed by an arbitrary string.""" - def __init__(self, name): + def __init__(self, name) -> None: super().__init__(name) self._counts = {} self._total = 0 - def inc(self, key, count=1): + def inc(self, key, count=1) -> None: """Increment the metric by the specified amount. Args: @@ -277,13 +278,13 @@ def inc(self, key, count=1): self._counts[key] = self._counts.get(key, 0) + count self._total += count - def _summary(self): + def _summary(self) -> str: details = ", ".join( ["%s=%d" % (k, self._counts[k]) for k in sorted(self._counts)] ) return "%d {%s}" % (self._total, details) - def _merge(self, other): + def _merge(self, other) -> None: # pylint: disable=protected-access for key, count in other._counts.items(): self._counts[key] = self._counts.get(key, 0) + count @@ -293,7 +294,7 @@ def _merge(self, other): class Distribution(Metric): """A metric to track simple statistics from a distribution of values.""" - def __init__(self, name): + def __init__(self, name) -> None: super().__init__(name) self._count = 0 # Number of values. self._total = 0.0 # Sum of the values. 
@@ -301,7 +302,7 @@ def __init__(self, name): self._min = None self._max = None - def add(self, value): + def add(self, value) -> None: """Add a value to the distribution.""" if not _enabled: return @@ -315,11 +316,11 @@ def add(self, value): self._min = min(self._min, value) self._max = max(self._max, value) - def _mean(self): + def _mean(self) -> Optional[float]: if self._count: return self._total / float(self._count) - def _stdev(self): + def _stdev(self) -> Optional[float]: if self._count: variance = (self._squared * self._count - self._total * self._total) / ( self._count * self._count @@ -330,7 +331,7 @@ def _stdev(self): return 0.0 return math.sqrt(variance) - def _summary(self): + def _summary(self) -> str: return "total=%s, count=%d, min=%s, max=%s, mean=%s, stdev=%s" % ( self._total, self._count, @@ -340,7 +341,7 @@ def _summary(self), self._stdev(), ) - def _merge(self, other): + def _merge(self, other) -> None: # pylint: disable=protected-access if other._count == 0: # Exit early so we don't have to worry about min/max of None. 
@@ -361,7 +362,7 @@ class Snapshot(Metric): def __init__( self, name, enabled=False, groupby="lineno", nframes=1, count=10 - ): + ) -> None: if enabled and tracemalloc is None: raise RuntimeError("tracemalloc module couldn't be imported") super().__init__(name) @@ -382,15 +383,15 @@ def __init__( # options.memory_snapshot flag set by the --memory-snapshots option) self.enabled = _enabled and enabled - def _start_tracemalloc(self): + def _start_tracemalloc(self) -> None: tracemalloc.start(self.nframes) self.running = True - def _stop_tracemalloc(self): + def _stop_tracemalloc(self) -> None: tracemalloc.stop() self.running = False - def take_snapshot(self, where=""): + def take_snapshot(self, where="") -> None: """Stores a tracemalloc snapshot.""" if not self.enabled: return @@ -409,26 +410,26 @@ def take_snapshot(self, where=""): ) ) - def __enter__(self): + def __enter__(self) -> None: if not self.enabled: return self._start_tracemalloc() self.take_snapshot("__enter__") - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None: if not self.running: return self.take_snapshot("__exit__") self._stop_tracemalloc() - def _summary(self): + def _summary(self) -> str: return "\n\n".join(self.snapshots) class MetricsContext: """A context manager that configures metrics and writes their output.""" - def __init__(self, output_path, open_function=open): + def __init__(self, output_path, open_function=open) -> None: """Initialize. Args: @@ -440,12 +441,12 @@ def __init__(self, output_path, open_function=open): self._open_function = open_function self._old_enabled = None # Set in __enter__. 
- def __enter__(self): + def __enter__(self) -> None: global _enabled self._old_enabled = _enabled _enabled = bool(self._output_path) - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None: global _enabled _enabled = self._old_enabled if self._output_path: diff --git a/pytype/module_utils.py b/pytype/module_utils.py index e0044d9b6..e99d34154 100644 --- a/pytype/module_utils.py +++ b/pytype/module_utils.py @@ -30,7 +30,7 @@ def full_path(self): return path_utils.join(self.path, self.target) -def infer_module(filename, pythonpath): +def infer_module(filename, pythonpath) -> Module: """Convert a filename to a module relative to pythonpath. This method tries to deduce the module name from the pythonpath and the @@ -83,11 +83,11 @@ def path_to_module_name(filename): return module_name -def strip_init_suffix(parts: Sequence[str]): +def strip_init_suffix(parts: Sequence[str]) -> Sequence[str]: return parts[:-1] if parts and parts[-1] == "__init__" else parts -def get_absolute_name(prefix, relative_name): +def get_absolute_name(prefix, relative_name) -> str | None: """Joins a dotted-name prefix and a relative name. Args: @@ -148,7 +148,7 @@ def get_relative_name(prefix: str, absolute_name: str) -> str: return name -def get_package_name(module_name, is_package=False): +def get_package_name(module_name, is_package=False) -> str: """Figure out a package name for a module.""" if module_name is None: return "" @@ -158,7 +158,7 @@ def get_package_name(module_name, is_package=False): return ".".join(parts) -def get_all_prefixes(module_name): +def get_all_prefixes(module_name) -> list: """Return all the prefixes of a module name. e.g. 
x.y.z => x, x.y, x.y.z diff --git a/pytype/output.py b/pytype/output.py index 9a331fe16..7f2d4070b 100644 --- a/pytype/output.py +++ b/pytype/output.py @@ -5,7 +5,7 @@ import enum import logging import re -from typing import cast +from typing import Any, Generator, cast from pytype import utils from pytype.abstract import abstract @@ -24,7 +24,7 @@ from pytype.pytd import pytd_utils from pytype.pytd import visitors -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) # A variable with more bindings than this is treated as a large literal constant @@ -96,12 +96,12 @@ def _get_values(self, node, var, view): else: return var.data - def _is_tuple(self, v, instance): + def _is_tuple(self, v, instance) -> bool: return isinstance(v, abstract.TupleClass) or isinstance( instance, abstract.Tuple ) - def _make_decorator(self, name, alias): + def _make_decorator(self, name, alias) -> pytd.Alias: # If decorators are output as aliases to NamedTypes, they will be converted # to Functions and fail a verification step if those functions have type # parameters. Since we just want the function name, and since we have a @@ -447,7 +447,7 @@ def value_to_pytd_type(self, node, v, seen, view): else: raise NotImplementedError(v.__class__.__name__) - def signature_to_callable(self, sig): + def signature_to_callable(self, sig) -> abstract.ParameterizedClass: """Converts a function.Signature object into a callable object. 
Args: @@ -557,7 +557,9 @@ def value_to_pytd_def(self, node, v, name): else: raise NotImplementedError(v.__class__.__name__) - def _ordered_attrs_to_instance_types(self, node, attr_metadata, annots): + def _ordered_attrs_to_instance_types( + self, node, attr_metadata, annots + ) -> Generator[tuple[Any, Any], Any, None]: """Get instance types for ordered attrs in the metadata.""" attrs = attr_metadata.get("attr_order", []) if not annots or not attrs: @@ -579,7 +581,9 @@ def _ordered_attrs_to_instance_types(self, node, attr_metadata, annots): typ = typ and typ.to_pytd_type_of_instance(node) yield a.name, typ - def annotations_to_instance_types(self, node, annots): + def annotations_to_instance_types( + self, node, annots + ) -> Generator[tuple[Any, Any], Any, None]: """Get instance types for annotations not present in the members map.""" if annots: for name, local in annots.annotated_locals.items(): @@ -605,7 +609,7 @@ def _function_call_to_return_type(self, node, v, seen_return, num_returns): def _function_call_combination_to_signature( self, func, call_combination, num_combinations - ): + ) -> pytd.Signature: node_after, combination, return_value = call_combination params = [] for i, (name, kind, optional) in enumerate(func.get_parameters()): @@ -662,7 +666,7 @@ def _function_call_combination_to_signature( template=(), ) - def _function_to_def(self, node, v, function_name): + def _function_to_def(self, node, v, function_name) -> pytd.Function: """Convert an InterpreterFunction to a PyTD definition.""" signatures = [] for func in v.signature_functions(): @@ -683,7 +687,7 @@ def _function_to_def(self, node, v, function_name): decorators=decorators, ) - def _simple_func_to_def(self, node, v, name): + def _simple_func_to_def(self, node, v, name) -> pytd.Function: """Convert a SimpleFunction to a PyTD definition.""" sig = v.signature @@ -738,7 +742,9 @@ def get_parameter(p, kind): ) return pytd.Function(name, (pytd_sig,), pytd.MethodKind.METHOD) - def 
_function_to_return_types(self, node, fvar, allowed_type_params=()): + def _function_to_return_types( + self, node, fvar, allowed_type_params=() + ) -> list: """Convert a function variable to a list of PyTD return types.""" options = fvar.FilteredData(self.ctx.exitpoint, strict=False) if not all(isinstance(o, abstract.Function) for o in options): @@ -774,7 +780,7 @@ def _is_instance(self, value, cls_name): isinstance(value, abstract.Instance) and value.cls.full_name == cls_name ) - def _class_to_def(self, node, v, class_name): + def _class_to_def(self, node, v, class_name) -> pytd.Class: """Convert an InterpreterClass to a PyTD definition.""" self._scopes.append(class_name) methods = {} @@ -1113,7 +1119,7 @@ def add_attributes_from(instance): self._scopes.pop() return cls - def _type_variable_to_def(self, node, v, name): + def _type_variable_to_def(self, node, v, name) -> pytd.TypeParameter: constraints = tuple(c.to_pytd_type_of_instance(node) for c in v.constraints) bound = v.bound and v.bound.to_pytd_type_of_instance(node) if isinstance(v, abstract.TypeParameter): @@ -1123,7 +1129,7 @@ def _type_variable_to_def(self, node, v, name): else: assert False, f"Unexpected type variable type: {type(v)}" - def _typed_dict_to_def(self, node, v, name): + def _typed_dict_to_def(self, node, v, name) -> pytd.Class: keywords = [] if not v.props.total: keywords.append(("total", pytd.Literal(False))) diff --git a/pytype/overlays/abc_overlay.py b/pytype/overlays/abc_overlay.py index e948197e9..6b839730a 100644 --- a/pytype/overlays/abc_overlay.py +++ b/pytype/overlays/abc_overlay.py @@ -1,9 +1,20 @@ """Implementation of special members of Python's abc library.""" +from typing import Any, TypeVar from pytype.abstract import abstract from pytype.overlays import overlay from pytype.overlays import special_builtins +_T0 = TypeVar("_T0") +_TAbstractClassMethod = TypeVar( + "_TAbstractClassMethod", bound="AbstractClassMethod" +) +_TAbstractMethod = TypeVar("_TAbstractMethod", 
bound="AbstractMethod") +_TAbstractProperty = TypeVar("_TAbstractProperty", bound="AbstractProperty") +_TAbstractStaticMethod = TypeVar( + "_TAbstractStaticMethod", bound="AbstractStaticMethod" +) + def _set_abstract(args, argname): if args.posargs: @@ -19,7 +30,7 @@ def _set_abstract(args, argname): class ABCOverlay(overlay.Overlay): """A custom overlay for the 'abc' module.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: member_map = { "abstractclassmethod": AbstractClassMethod.make, "abstractmethod": AbstractMethod.make, @@ -37,10 +48,12 @@ class AbstractClassMethod(special_builtins.ClassMethod): """Implements abc.abstractclassmethod.""" @classmethod - def make(cls, ctx, module): + def make( + cls: type[_TAbstractClassMethod], ctx, module + ) -> _TAbstractClassMethod: return super().make_alias("abstractclassmethod", ctx, module) - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, alias_map=None) -> tuple[_T0, Any]: _ = _set_abstract(args, "callable") return super().call(node, func, args, alias_map) @@ -49,10 +62,10 @@ class AbstractMethod(abstract.PyTDFunction): """Implements the @abc.abstractmethod decorator.""" @classmethod - def make(cls, ctx, module): + def make(cls: type[_TAbstractMethod], ctx, module) -> _TAbstractMethod: return super().make("abstractmethod", ctx, module) - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, alias_map=None) -> tuple[_T0, Any]: """Marks that the given function is abstract.""" del func, alias_map # unused self.match_args(node, args) @@ -63,10 +76,10 @@ class AbstractProperty(special_builtins.Property): """Implements the @abc.abstractproperty decorator.""" @classmethod - def make(cls, ctx, module): + def make(cls: type[_TAbstractProperty], ctx, module) -> _TAbstractProperty: return super().make_alias("abstractproperty", ctx, module) - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, 
alias_map=None) -> tuple[_T0, Any]: property_args = self._get_args(args) for v in property_args.values(): for b in v.bindings: @@ -86,9 +99,11 @@ class AbstractStaticMethod(special_builtins.StaticMethod): """Implements abc.abstractstaticmethod.""" @classmethod - def make(cls, ctx, module): + def make( + cls: type[_TAbstractStaticMethod], ctx, module + ) -> _TAbstractStaticMethod: return super().make_alias("abstractstaticmethod", ctx, module) - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, alias_map=None) -> tuple[_T0, Any]: _ = _set_abstract(args, "callable") return super().call(node, func, args, alias_map) diff --git a/pytype/overlays/asyncio_types_overlay.py b/pytype/overlays/asyncio_types_overlay.py index 3921c00de..b5cae2250 100644 --- a/pytype/overlays/asyncio_types_overlay.py +++ b/pytype/overlays/asyncio_types_overlay.py @@ -1,14 +1,20 @@ """Implementation of special members of types and asyncio module.""" +from typing import Any, TypeVar from pytype.abstract import abstract from pytype.abstract import abstract_utils from pytype.overlays import overlay +_T0 = TypeVar("_T0") +_TCoroutineDecorator = TypeVar( + "_TCoroutineDecorator", bound="CoroutineDecorator" +) + class TypesOverlay(overlay.Overlay): """A custom overlay for the 'types' module.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: member_map = {"coroutine": CoroutineDecorator.make} ast = ctx.loader.import_name("types") super().__init__(ctx, "types", member_map, ast) @@ -17,7 +23,7 @@ def __init__(self, ctx): class AsyncioOverlay(overlay.Overlay): """A custom overlay for the 'asyncio' module.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: member_map = {} if ctx.python_version <= (3, 10): member_map["coroutine"] = CoroutineDecorator.make @@ -29,10 +35,12 @@ class CoroutineDecorator(abstract.PyTDFunction): """Implements the @types.coroutine and @asyncio.coroutine decorator.""" @classmethod - def make(cls, ctx, module): + def 
make( + cls: type[_TCoroutineDecorator], ctx, module + ) -> _TCoroutineDecorator: return super().make("coroutine", ctx, module) - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, alias_map=None) -> tuple[_T0, Any]: """Marks the function as a generator-based coroutine.""" del func, alias_map # unused self.match_args(node, args) diff --git a/pytype/overlays/attr_overlay.py b/pytype/overlays/attr_overlay.py index f2aabab02..a965cf650 100644 --- a/pytype/overlays/attr_overlay.py +++ b/pytype/overlays/attr_overlay.py @@ -14,11 +14,20 @@ from pytype.overlays import overlay from pytype.overlays import overlay_utils -log = logging.getLogger(__name__) +_T0 = TypeVar("_T0") +_T1 = TypeVar("_T1") +_TAttrib = TypeVar("_TAttrib", bound="Attrib") +_TAttribInstance = TypeVar("_TAttribInstance", bound="AttribInstance") +_TAttrs = TypeVar("_TAttrs", bound="Attrs") +_TAttrsNextGenDefine = TypeVar( + "_TAttrsNextGenDefine", bound="AttrsNextGenDefine" +) + +log: logging.Logger = logging.getLogger(__name__) # type aliases for convenience -Param = overlay_utils.Param -Attribute = classgen.Attribute +Param: type[overlay_utils.Param] = overlay_utils.Param +Attribute: type[classgen.Attribute] = classgen.Attribute _TBaseValue = TypeVar("_TBaseValue", bound=abstract.BaseValue) @@ -36,7 +45,7 @@ class _AttrOverlayBase(overlay.Overlay): _MODULE_NAME: str - def __init__(self, ctx): + def __init__(self, ctx) -> None: member_map = { # Attr's next-gen APIs # See https://www.attrs.org/en/stable/api.html#next-gen @@ -61,7 +70,7 @@ class AttrOverlay(_AttrOverlayBase): _MODULE_NAME = "attr" - def __init__(self, ctx): + def __init__(self, ctx) -> None: super().__init__(ctx) self._member_map.update({ "attrs": Attrs.make, @@ -88,7 +97,7 @@ class _NoChange: # A unique sentinel value to signal not to write anything, not even the # original value. 
-_NO_CHANGE = _NoChange() +_NO_CHANGE: _NoChange = _NoChange() class AttrsBase(classgen.Decorator): @@ -108,7 +117,7 @@ def _handle_auto_attribs( # writing even the same value might have a side effect (changing ordering). return _NO_CHANGE, _ordering_for_auto_attrib(auto_attribs) - def decorate(self, node, cls): + def decorate(self, node, cls) -> None: """Processes the attrib members of a class.""" # Collect classvars to convert them to attrs. new_auto_attribs, ordering = self._handle_auto_attribs( @@ -252,7 +261,7 @@ def decorate(self, node, cls): # Fix up type parameters in methods added by the decorator. cls.update_method_type_params() - def to_metadata(self): + def to_metadata(self) -> dict[str, Any]: # For simplicity, we give all attrs decorators with the same behavior as # attr.s the same tag. args = self._current_args or self.DEFAULT_ARGS @@ -268,11 +277,11 @@ class Attrs(AttrsBase): """Implements the @attr.s decorator.""" @classmethod - def make(cls, ctx, module="attr"): + def make(cls: type[_TAttrs], ctx, module="attr") -> _TAttrs: return super().make("s", ctx, module) @classmethod - def make_dataclass(cls, ctx, module): + def make_dataclass(cls: type[_TAttrs], ctx, module) -> _TAttrs: ret = super().make("s", ctx, module) ret.partial_args["auto_attribs"] = True return ret @@ -307,10 +316,14 @@ class AttrsNextGenDefine(AttrsBase): } @classmethod - def make(cls, ctx, module): + def make( + cls: type[_TAttrsNextGenDefine], ctx, module + ) -> _TAttrsNextGenDefine: return super().make("define", ctx, module) - def _handle_auto_attribs(self, auto_attribs, local_ops, cls_name): + def _handle_auto_attribs( + self, auto_attribs, local_ops, cls_name + ) -> tuple[bool | _NoChange | None, Any]: if auto_attribs is not None: return super()._handle_auto_attribs(auto_attribs, local_ops, cls_name) is_annotated = {} @@ -329,7 +342,9 @@ def _handle_auto_attribs(self, auto_attribs, local_ops, cls_name): class AttribInstance(abstract.SimpleValue, mixin.HasSlots): """Return 
value of an attr.ib() call.""" - def __init__(self, ctx, typ, type_source, init, init_type, kw_only, default): + def __init__( + self, ctx, typ, type_source, init, init_type, kw_only, default + ) -> None: super().__init__("attrib", ctx) mixin.HasSlots.init_mixin(self) self.typ = typ @@ -343,7 +358,7 @@ def __init__(self, ctx, typ, type_source, init, init_type, kw_only, default): self.set_native_slot("default", self.default_slot) self.set_native_slot("validator", self.validator_slot) - def default_slot(self, node, default): + def default_slot(self, node, default: _T1) -> tuple[Any, _T1]: # If the default is a method, call it and use its return type. fn = default.data[0] # TODO(mdemello): it is not clear what to use for self in fn_args; using @@ -373,10 +388,10 @@ def default_slot(self, node, default): # Return the original decorated method so we don't lose it. return node, default - def validator_slot(self, node, validator): + def validator_slot(self, node: _T0, validator: _T1) -> tuple[_T0, _T1]: return node, validator - def to_metadata(self): + def to_metadata(self) -> dict[str, Any]: type_source = self.type_source and self.type_source.name return { "tag": "attr.ib", @@ -387,7 +402,9 @@ def to_metadata(self): } @classmethod - def from_metadata(cls, ctx, node, typ, metadata): + def from_metadata( + cls: type[_TAttribInstance], ctx, node, typ, metadata + ) -> _TAttribInstance: init = metadata["init"] kw_only = metadata["kw_only"] type_source = metadata["type_source"] @@ -402,7 +419,7 @@ class Attrib(classgen.FieldConstructor): """Implements attr.ib/attrs.field.""" @classmethod - def make(cls, ctx, module): + def make(cls: type[_TAttrib], ctx, module) -> _TAttrib: return super().make("ib" if module == "attr" else "field", ctx, module) def _match_and_discard_args(self, node, funcb, args): @@ -426,7 +443,7 @@ def _match_and_discard_args(self, node, funcb, args): args = args.replace_namedarg("default", self.ctx.new_unsolvable(node)) return args - def call(self, node, 
func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple[Any, Any]: """Returns a type corresponding to an attr.""" args = args.simplify(node, self.ctx) args = self._match_and_discard_args(node, func, args) @@ -507,7 +524,9 @@ def valid_arity(sig): raise error_types.WrongArgTypes(self.sig, args, self.ctx, bad_param) return valid_sigs[0] - def _call_converter_function(self, node, converter_var, args): + def _call_converter_function( + self, node, converter_var, args + ) -> tuple[Any, Any]: """Run converter and return the input and return types.""" binding = converter_var.bindings[0] fn = binding.data @@ -540,7 +559,7 @@ def _get_converter_types(self, node, args): else: return None, None - def _get_default_var(self, node, args): + def _get_default_var(self, node, args) -> tuple[Any, Any]: if "default" in args.namedargs and "factory" in args.namedargs: # attr.ib(factory=x) is syntactic sugar for attr.ib(default=Factory(x)). raise error_types.DuplicateKeyword(self.sig, args, self.ctx, "default") @@ -568,7 +587,7 @@ def _ordering_for_auto_attrib(auto_attrib): ) -def is_attrib(var): +def is_attrib(var: _T0) -> bool | _T0: return var and isinstance(var.data[0], AttribInstance) diff --git a/pytype/overlays/chex_overlay.py b/pytype/overlays/chex_overlay.py index 678239385..7993dc31b 100644 --- a/pytype/overlays/chex_overlay.py +++ b/pytype/overlays/chex_overlay.py @@ -7,6 +7,8 @@ * Chex dataclasses have replace, from_tuple, and to_tuple methods. 
""" +from typing import Any + from pytype.abstract import abstract from pytype.overlays import classgen from pytype.overlays import dataclass_overlay @@ -17,7 +19,7 @@ class ChexOverlay(overlay.Overlay): - def __init__(self, ctx): + def __init__(self, ctx) -> None: member_map = { "dataclass": Dataclass.make, } @@ -28,17 +30,17 @@ def __init__(self, ctx): class Dataclass(dataclass_overlay.Dataclass): """Implements the @dataclass decorator.""" - DEFAULT_ARGS = { + DEFAULT_ARGS: dict[str, Any] = { **dataclass_overlay.Dataclass.DEFAULT_ARGS, "mappable_dataclass": True, } - def _add_replace_method(self, node, cls): + def _add_replace_method(self, node, cls) -> None: cls.members["replace"] = classgen.make_replace_method( self.ctx, node, cls, kwargs_name="changes" ) - def _add_from_tuple_method(self, node, cls): + def _add_from_tuple_method(self, node, cls) -> None: # from_tuple is discouraged anyway, so we provide only bare-bones types. cls.members["from_tuple"] = overlay_utils.make_method( ctx=self.ctx, @@ -49,7 +51,7 @@ def _add_from_tuple_method(self, node, cls): kind=pytd.MethodKind.STATICMETHOD, ) - def _add_to_tuple_method(self, node, cls): + def _add_to_tuple_method(self, node, cls) -> None: # to_tuple is discouraged anyway, so we provide only bare-bones types. 
cls.members["to_tuple"] = overlay_utils.make_method( ctx=self.ctx, @@ -58,7 +60,7 @@ def _add_to_tuple_method(self, node, cls): return_type=self.ctx.convert.tuple_type, ) - def _add_mapping_methods(self, node, cls): + def _add_mapping_methods(self, node, cls) -> None: if "__getitem__" not in cls.members: cls.members["__getitem__"] = overlay_utils.make_method( ctx=self.ctx, @@ -82,7 +84,7 @@ def _add_mapping_methods(self, node, cls): return_type=self.ctx.convert.int_type, ) - def decorate(self, node, cls): + def decorate(self, node, cls) -> None: super().decorate(node, cls) if not isinstance(cls, abstract.InterpreterClass): return diff --git a/pytype/overlays/classgen.py b/pytype/overlays/classgen.py index 8bd1afd08..57c936182 100644 --- a/pytype/overlays/classgen.py +++ b/pytype/overlays/classgen.py @@ -7,7 +7,7 @@ import collections import dataclasses import logging -from typing import Any, ClassVar +from typing import Any, ClassVar, TypeVar from pytype.abstract import abstract from pytype.abstract import abstract_utils @@ -15,14 +15,17 @@ from pytype.overlays import overlay_utils from pytype.overlays import special_builtins +_T0 = TypeVar("_T0") +_TClassProperties = TypeVar("_TClassProperties", bound="ClassProperties") -log = logging.getLogger(__name__) + +log: logging.Logger = logging.getLogger(__name__) # type aliases for convenience -Param = overlay_utils.Param -Attribute = class_mixin.Attribute -AttributeKinds = class_mixin.AttributeKinds +Param: type[overlay_utils.Param] = overlay_utils.Param +Attribute: type[class_mixin.Attribute] = class_mixin.Attribute +AttributeKinds: type[class_mixin.AttributeKinds] = class_mixin.AttributeKinds # Probably should make this an enum.Enum at some point. @@ -37,7 +40,7 @@ class Ordering: # the locals will be [(x, Instance(float)), (y, Instance(str))]. Note that # unannotated variables will be skipped, and the values of later annotations # take precedence over earlier ones. 
- FIRST_ANNOTATE = object() + FIRST_ANNOTATE: Any = object() # Order by each variable's last definition. So for # class Foo: # x = 0 @@ -45,7 +48,7 @@ class Ordering: # x = 4.2 # the locals will be [(y, Instance(str)), (x, Instance(float))]. Note that # variables without assignments will be skipped. - LAST_ASSIGN = object() + LAST_ASSIGN: Any = object() class Decorator(abstract.PyTDFunction, metaclass=abc.ABCMeta): @@ -59,7 +62,7 @@ class Decorator(abstract.PyTDFunction, metaclass=abc.ABCMeta): "auto_attribs": False, } - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) # Decorator.call() is invoked first with args, then with the class to # decorate, so we need to first store the args and then associate them to @@ -74,12 +77,12 @@ def __init__(self, *args, **kwargs): def decorate(self, node, cls): """Apply the decorator to cls.""" - def get_initial_args(self): + def get_initial_args(self) -> dict[str, Any]: ret = self.DEFAULT_ARGS.copy() ret.update(self.partial_args) return ret - def update_kwargs(self, args): + def update_kwargs(self, args) -> None: """Update current_args with the Args passed to the decorator.""" self._current_args = self.get_initial_args() for k, v in args.namedargs.items(): @@ -91,7 +94,7 @@ def update_kwargs(self, args): self.ctx.vm.frames, f"Non-constant argument to decorator: {k!r}" ) - def set_current_args(self, kwargs): + def set_current_args(self, kwargs) -> None: """Set current_args when constructing a class directly.""" self._current_args = self.get_initial_args() self._current_args.update(kwargs) @@ -123,7 +126,7 @@ def make_init(self, node, cls, attrs, init_method_name="__init__"): self.ctx, node, init_method_name, pos_params, 0, kwonly_params ) - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, alias_map=None) -> tuple[_T0, Any]: """Construct a decorator, and call it on the class.""" args = args.simplify(node, self.ctx) 
self.match_args(node, args) @@ -188,14 +191,14 @@ def get_kwarg(self, args, name, default): self.ctx.vm.frames, f"Non-constant argument {name!r}" ) - def get_positional_names(self): + def get_positional_names(self) -> list[None]: # TODO(mdemello): We currently assume all field constructors are called with # namedargs, which has worked in practice but is not required by the attrs # or dataclasses apis. return [] -def is_method(var): +def is_method(var) -> bool: if var is None: return False return isinstance( @@ -213,7 +216,7 @@ def is_dunder(name): return name.startswith("__") and name.endswith("__") -def add_member(node, cls, name, typ): +def add_member(node, cls, name, typ) -> None: if typ.formal: # If typ contains a type parameter, we mark it as empty so that instances # will use __annotations__ to fill in concrete type parameter values. @@ -254,7 +257,9 @@ def is_relevant_class_local( return True -def get_class_locals(cls_name: str, allow_methods: bool, ordering, ctx): +def get_class_locals( + cls_name: str, allow_methods: bool, ordering, ctx +) -> collections.OrderedDict: """Gets a dictionary of the class's local variables. 
Args: @@ -342,7 +347,9 @@ class ClassProperties: bases: list[Any] @classmethod - def from_field_names(cls, name, field_names, ctx): + def from_field_names( + cls: type[_TClassProperties], name, field_names, ctx + ) -> _TClassProperties: """Make a ClassProperties from field names with no types.""" fields = [Field(n, ctx.convert.unsolvable, None) for n in field_names] return cls(name, fields, []) diff --git a/pytype/overlays/collections_overlay.py b/pytype/overlays/collections_overlay.py index 562652da5..27630ca14 100644 --- a/pytype/overlays/collections_overlay.py +++ b/pytype/overlays/collections_overlay.py @@ -1,5 +1,6 @@ """Implementation of types from Python 2's collections library.""" +from collections.abc import Callable from pytype.overlays import named_tuple from pytype.overlays import overlay from pytype.overlays import typing_overlay @@ -8,7 +9,7 @@ class CollectionsOverlay(overlay.Overlay): """A custom overlay for the 'collections' module.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: """Initializes the CollectionsOverlay. This function loads the AST for the collections module, which is used to @@ -24,7 +25,7 @@ def __init__(self, ctx): super().__init__(ctx, "collections", member_map, ast) -collections_overlay = { +collections_overlay: dict[str, Callable] = { "namedtuple": named_tuple.CollectionsNamedTupleBuilder.make, } @@ -32,7 +33,7 @@ def __init__(self, ctx): class ABCOverlay(typing_overlay.Redirect): """A custom overlay for the 'collections.abc' module.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: # collections.abc.Set equates to typing.AbstractSet rather than typing.Set. # This is the only such mismatch. 
aliases = {"Set": "typing.AbstractSet"} diff --git a/pytype/overlays/dataclass_overlay.py b/pytype/overlays/dataclass_overlay.py index 49c0df868..023ccab27 100644 --- a/pytype/overlays/dataclass_overlay.py +++ b/pytype/overlays/dataclass_overlay.py @@ -4,7 +4,9 @@ # - Raise an error if we see a duplicate annotation, even though python allows # it, since there is no good reason to do that. +from collections import OrderedDict import logging +from typing import Any, TypeVar from pytype.abstract import abstract from pytype.abstract import abstract_utils @@ -13,13 +15,18 @@ from pytype.overlays import classgen from pytype.overlays import overlay -log = logging.getLogger(__name__) +_T0 = TypeVar("_T0") +_TDataclass = TypeVar("_TDataclass", bound="Dataclass") +_TFieldFunction = TypeVar("_TFieldFunction", bound="FieldFunction") +_TReplace = TypeVar("_TReplace", bound="Replace") + +log: logging.Logger = logging.getLogger(__name__) class DataclassOverlay(overlay.Overlay): """A custom overlay for the 'dataclasses' module.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: member_map = { "dataclass": Dataclass.make, "field": FieldFunction.make, @@ -33,7 +40,7 @@ class Dataclass(classgen.Decorator): """Implements the @dataclass decorator.""" @classmethod - def make(cls, ctx, module="dataclasses"): + def make(cls: type[_TDataclass], ctx, module="dataclasses") -> _TDataclass: return super().make("dataclass", ctx, module) @classmethod @@ -64,7 +71,7 @@ def _handle_initvar(self, node, cls, name, typ, orig): classgen.add_member(node, cls, name, initvar) return initvar - def get_class_locals(self, node, cls): + def get_class_locals(self, node, cls) -> OrderedDict: del node return classgen.get_class_locals( cls.name, @@ -73,7 +80,7 @@ def get_class_locals(self, node, cls): ctx=self.ctx, ) - def decorate(self, node, cls): + def decorate(self, node, cls) -> None: """Processes class members.""" # Collect classvars to convert them to attrs. 
@dataclass collects vars with @@ -193,7 +200,7 @@ def decorate(self, node, cls): class FieldInstance(abstract.SimpleValue): """Return value of a field() call.""" - def __init__(self, ctx, init, default, kw_only): + def __init__(self, ctx, init, default, kw_only) -> None: super().__init__("field", ctx) self.init = init self.default = default @@ -205,10 +212,10 @@ class FieldFunction(classgen.FieldConstructor): """Implements dataclasses.field.""" @classmethod - def make(cls, ctx, module): + def make(cls: type[_TFieldFunction], ctx, module) -> _TFieldFunction: return super().make("field", ctx, module) - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple[Any, Any]: """Returns a type corresponding to a field.""" args = args.simplify(node, self.ctx) self.match_args(node, args) @@ -218,7 +225,7 @@ def call(self, node, func, args, alias_map=None): typ = FieldInstance(self.ctx, init, default_var, kw_only).to_variable(node) return node, typ - def _get_default_var(self, node, args): + def _get_default_var(self, node, args) -> tuple[Any, Any]: if "default" in args.namedargs and "default_factory" in args.namedargs: # The pyi signatures should prevent this; check left in for safety. 
raise error_types.DuplicateKeyword( @@ -237,7 +244,7 @@ def _get_default_var(self, node, args): return node, default_var -def is_field(var): +def is_field(var: _T0) -> bool | _T0: return var and isinstance(var.data[0], FieldInstance) @@ -255,7 +262,7 @@ class Replace(abstract.PyTDFunction): """Implements dataclasses.replace.""" @classmethod - def make(cls, ctx, module="dataclasses"): + def make(cls: type[_TReplace], ctx, module="dataclasses") -> _TReplace: return super().make("replace", ctx, module) def _match_args_sequentially(self, node, args, alias_map, match_all_views): diff --git a/pytype/overlays/enum_overlay.py b/pytype/overlays/enum_overlay.py index 65ea340bf..977051313 100644 --- a/pytype/overlays/enum_overlay.py +++ b/pytype/overlays/enum_overlay.py @@ -25,6 +25,7 @@ import collections import contextlib import logging +from typing import Any, TypeVar from pytype.abstract import abstract from pytype.abstract import abstract_utils @@ -39,11 +40,14 @@ from pytype.pytd import pytd_utils from pytype.typegraph import cfg -log = logging.getLogger(__name__) + +_T0 = TypeVar("_T0") + +log: logging.Logger = logging.getLogger(__name__) # These members have been added in Python 3.11 and are not yet supported. 
-_unsupported = ( +_unsupported: tuple[str, str, str, str, str, str, str, str, str] = ( "ReprEnum", "EnumCheck", "FlagBoundary", @@ -59,7 +63,7 @@ class EnumOverlay(overlay.Overlay): """An overlay for the enum std lib module.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: member_map = { "Enum": overlay.add_name("Enum", EnumBuilder), "EnumMeta": EnumMeta, @@ -78,7 +82,7 @@ def __init__(self, ctx): class EnumBuilder(abstract.PyTDClass): """Overlays enum.Enum.""" - def __init__(self, name, ctx, module): + def __init__(self, name, ctx, module) -> None: super().__init__(name, ctx.loader.lookup_pytd(module, name), ctx) def make_class(self, node, props): @@ -205,7 +209,7 @@ def call(self, node, func, args, alias_map=None): class EnumInstance(abstract.InterpreterClass): """A wrapper for classes that subclass enum.Enum.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) # These are set by EnumMetaInit.setup_interpreterclass. self.member_type = None @@ -246,7 +250,7 @@ def instantiate(self, node, container=None): instance.members[attr_name] = attr_type.instantiate(node) return instance.to_variable(node) - def is_empty_enum(self): + def is_empty_enum(self) -> bool: for member in self.members.values(): for b in member.data: if b.cls == self: @@ -272,7 +276,7 @@ class EnumCmpEQ(abstract.SimpleFunction): # comparing the members' names. However, this causes issues when enums are # used in an if statement; see the bug for examples. 
- def __init__(self, ctx): + def __init__(self, ctx) -> None: sig = function.Signature( name="__eq__", param_names=("self", "other"), @@ -287,7 +291,7 @@ def __init__(self, ctx): ) super().__init__(sig, ctx) - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, alias_map=None) -> tuple[_T0, Any]: _, argmap = self.match_and_map_args(node, args, alias_map) this_var = argmap["self"] other_var = argmap["other"] @@ -308,7 +312,7 @@ class EnumMeta(abstract.PyTDClass): enum behavior: EnumMetaInit for modifying enum classes, for example. """ - def __init__(self, ctx, module): + def __init__(self, ctx, module) -> None: pytd_cls = ctx.loader.lookup_pytd(module, "EnumMeta") super().__init__("EnumMeta", pytd_cls, ctx) init = EnumMetaInit(ctx) @@ -326,7 +330,7 @@ class EnumMetaInit(abstract.SimpleFunction): handling and set up the Enum classes correctly. """ - def __init__(self, ctx): + def __init__(self, ctx) -> None: sig = function.Signature( name="__init__", param_names=("cls", "name", "bases", "namespace"), @@ -521,7 +525,7 @@ def _is_orig_auto(self, orig): and data.cls.full_name == "enum.auto" ) - def _call_generate_next_value(self, node, cls, name): + def _call_generate_next_value(self, node, cls, name) -> tuple[Any, Any]: node, method = self.ctx.attribute_handler.get_attribute( node, cls, "_generate_next_value_", cls.to_binding(node) ) @@ -563,7 +567,7 @@ def _value_to_starargs(self, node, value_var, base_type): args = self.ctx.convert.build_tuple(node, [args]) return args - def _mark_dynamic_enum(self, cls): + def _mark_dynamic_enum(self, cls) -> None: # Checks if the enum should be marked as having dynamic attributes. # Of course, if it's already marked dynamic, don't accidentally unmark it. 
if cls.maybe_missing_members: @@ -711,7 +715,7 @@ def _setup_interpreterclass(self, node, cls): cls.members["_generate_next_value_"] = new_gnv return node - def _setup_pytdclass(self, node, cls): + def _setup_pytdclass(self, node: _T0, cls) -> _T0: # Only constants need to be transformed. We assume that enums in type # stubs are fully realized, i.e. there are no auto() calls and the members # already have values of the base type. @@ -783,7 +787,7 @@ def _setup_pytdclass(self, node, cls): cls.members["__new__"] = self._make_new(node, member_type, cls) return node - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple[Any, Any]: # Use super.call to check args and get a return value. node, ret = super().call(node, func, args, alias_map) argmap = self._map_args(node, args) @@ -817,7 +821,7 @@ def call(self, node, func, args, alias_map=None): class EnumMetaGetItem(abstract.SimpleFunction): """Implements the functionality of __getitem__ for enums.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: sig = function.Signature( name="__getitem__", param_names=("cls", "name"), @@ -840,7 +844,7 @@ def _get_member_by_name( enum.load_lazy_attribute(name) return enum.members[name] - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, alias_map=None) -> tuple[_T0, Any]: _, argmap = self.match_and_map_args(node, args, alias_map) cls_var = argmap["cls"] name_var = argmap["name"] diff --git a/pytype/overlays/fiddle_overlay.py b/pytype/overlays/fiddle_overlay.py index 17006d2ff..c01d86eda 100644 --- a/pytype/overlays/fiddle_overlay.py +++ b/pytype/overlays/fiddle_overlay.py @@ -1,7 +1,7 @@ """Implementation of types from the fiddle library.""" import re -from typing import Any +from typing import Any, TypeVar from pytype.abstract import abstract from pytype.abstract import abstract_utils @@ -12,9 +12,12 @@ from pytype.pytd import pytd +_TBuildableType = 
TypeVar("_TBuildableType", bound="BuildableType") + + # Type aliases so we aren't importing stuff purely for annotations -Node = Any -Variable = Any +Node: Any = Any +Variable: Any = Any # Cache instances, so that we don't generate two different classes when @@ -24,7 +27,7 @@ _INSTANCE_CACHE: dict[tuple[Node, abstract.Class, str], abstract.Instance] = {} -_CLASS_ALIASES = { +_CLASS_ALIASES: dict[str, str] = { "Config": "Config", "PaxConfig": "Config", "Partial": "Partial", @@ -35,7 +38,7 @@ class FiddleOverlay(overlay.Overlay): """A custom overlay for the 'fiddle' module.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: """Initializes the FiddleOverlay. This function loads the AST for the fiddle module, which is used to @@ -60,7 +63,7 @@ def __init__(self, ctx): class BuildableBuilder(abstract.PyTDClass, mixin.HasSlots): """Factory for creating fiddle.Config classes.""" - def __init__(self, name, ctx, module): + def __init__(self, name, ctx, module) -> None: pytd_cls = ctx.loader.lookup_pytd(module, name) # fiddle.Config/Partial loads as a LateType, convert to pytd.Class if isinstance(pytd_cls, pytd.Constant): @@ -72,10 +75,10 @@ def __init__(self, name, ctx, module): self.fiddle_type_name = _CLASS_ALIASES[name] self.module = module - def __repr__(self): + def __repr__(self) -> str: return f"FiddleBuildableBuilder[{self.name}]" - def _match_pytd_init(self, node, init_var, args): + def _match_pytd_init(self, node, init_var, args) -> None: init = init_var.data[0] old_pytd_sigs = [] for signature in init.signatures: @@ -90,7 +93,7 @@ def _match_pytd_init(self, node, init_var, args): for signature, old_pytd_sig in zip(init.signatures, old_pytd_sigs): signature.pytd_sig = old_pytd_sig - def _match_interpreter_init(self, node, init_var, args): + def _match_interpreter_init(self, node, init_var, args) -> None: # Buildables support partial initialization, so give every parameter a # default when matching __init__. 
init = init_var.data[0] @@ -109,7 +112,7 @@ def _match_interpreter_init(self, node, init_var, args): else: del init.signature.defaults[k] - def _make_init_args(self, node, underlying, args, kwargs): + def _make_init_args(self, node, underlying, args, kwargs) -> function.Args: """Unwrap Config instances for arg matching.""" def unwrap(arg_var): @@ -137,7 +140,7 @@ def unwrap(arg_var): new_kwargs = {k: unwrap(arg) for k, arg in kwargs.items()} return function.Args(posargs=new_args, namedargs=new_kwargs) - def _check_init_args(self, node, underlying, args, kwargs): + def _check_init_args(self, node, underlying, args, kwargs) -> None: # Configs can be initialized either with no args, e.g. Config(Class) or with # initial values, e.g. Config(Class, x=10, y=20). We need to check here that # the extra args match the underlying __init__ signature. @@ -183,7 +186,9 @@ def get_own_new(self, node, value) -> tuple[Node, Variable]: class BuildableType(abstract.ParameterizedClass): """Base generic class for fiddle.Config and fiddle.Partial.""" - def __init__(self, base_cls, underlying, ctx, template=None, module="fiddle"): + def __init__( + self, base_cls, underlying, ctx, template=None, module="fiddle" + ) -> None: if isinstance(underlying, abstract.FUNCTION_TYPES): # We don't support functions for now, but falling back to Any here gets us # as much of the functionality as possible. 
@@ -196,15 +201,22 @@ def __init__(self, base_cls, underlying, ctx, template=None, module="fiddle"): # Classes and TypeVars formal_type_parameters = {abstract_utils.T: underlying} - super().__init__(base_cls, formal_type_parameters, ctx, template) # pytype: disable=wrong-arg-types + super().__init__( + base_cls, formal_type_parameters, ctx, template + ) # pytype: disable=wrong-arg-types self.fiddle_type_name = base_cls.fiddle_type_name self.underlying = underlying self.module = module @classmethod def make( - cls, fiddle_type_name, underlying, ctx, template=None, module="fiddle" - ): + cls: type[_TBuildableType], + fiddle_type_name, + underlying, + ctx, + template=None, + module="fiddle", + ) -> _TBuildableType: base_cls = BuildableBuilder(fiddle_type_name, ctx, module) return cls(base_cls, underlying, ctx, template, module) @@ -226,14 +238,14 @@ def instantiate(self, node, container=None): ) return ret.to_variable(node) - def __repr__(self): + def __repr__(self) -> str: return f"{self.fiddle_type_name}Type[{self.underlying}]" class Buildable(abstract.Instance, mixin.HasSlots): """Base class for Config and Partial instances.""" - def __init__(self, fiddle_type_name, cls, ctx, container=None): + def __init__(self, fiddle_type_name, cls, ctx, container=None) -> None: super().__init__(cls, ctx, container) self.fiddle_type_name = fiddle_type_name self.underlying = None @@ -252,18 +264,18 @@ def getitem_slot(self, node, slice_var) -> tuple[Node, abstract.Instance]: class Config(Buildable): """An instantiation of a fiddle.Config with a particular template.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__("Config", *args, **kwargs) class Partial(Buildable): """An instantiation of a fiddle.Partial with a particular template.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__("Partial", *args, **kwargs) -def _convert_type(typ, subst, ctx): +def _convert_type(typ, 
subst, ctx) -> abstract.Union: """Helper function for recursive type conversion of fields.""" if isinstance(typ, abstract.TypeParameter) and typ.name in subst: # TODO(mdemello): Handle typevars in unions. diff --git a/pytype/overlays/flax_overlay.py b/pytype/overlays/flax_overlay.py index 1e90d5961..eacb0cbd0 100644 --- a/pytype/overlays/flax_overlay.py +++ b/pytype/overlays/flax_overlay.py @@ -11,6 +11,7 @@ # frozen anyway we needn't bother about that for now. +from typing import TypeVar from pytype.abstract import abstract from pytype.abstract import abstract_utils from pytype.abstract import function @@ -19,11 +20,13 @@ from pytype.overlays import overlay from pytype.pytd import pytd +_T0 = TypeVar("_T0") + class DataclassOverlay(overlay.Overlay): """A custom overlay for the 'flax.struct' module.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: member_map = { "dataclass": Dataclass.make, } @@ -34,7 +37,7 @@ def __init__(self, ctx): class Dataclass(dataclass_overlay.Dataclass): """Implements the @dataclass decorator.""" - def decorate(self, node, cls): + def decorate(self, node, cls) -> None: super().decorate(node, cls) if not isinstance(cls, abstract.InterpreterClass): return @@ -51,7 +54,7 @@ def decorate(self, node, cls): class LinenOverlay(overlay.Overlay): """A custom overlay for the 'flax.linen' module.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: member_map = { "Module": Module, } @@ -62,7 +65,7 @@ def __init__(self, ctx): class LinenModuleOverlay(overlay.Overlay): """A custom overlay for the 'flax.linen.module' module.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: member_map = { "Module": Module, } @@ -73,7 +76,7 @@ def __init__(self, ctx): class ModuleDataclass(dataclass_overlay.Dataclass): """Dataclass with automatic 'name' and 'parent' members.""" - def _add_implicit_field(self, node, cls_locals, key, typ): + def _add_implicit_field(self, node, cls_locals, key, typ) -> None: if key in cls_locals: 
self.ctx.errorlog.invalid_annotation( self.ctx.vm.frames, @@ -100,7 +103,7 @@ def make_initvar(t): self._add_implicit_field(node, cls_locals, "parent", parent_type) return cls_locals - def decorate(self, node, cls): + def decorate(self, node, cls) -> None: super().decorate(node, cls) if not isinstance(cls, abstract.InterpreterClass): return @@ -110,13 +113,13 @@ def decorate(self, node, cls): class Module(abstract.PyTDClass): """Construct a dataclass for any class inheriting from Module.""" - IMPLICIT_FIELDS = ("name", "parent") + IMPLICIT_FIELDS: tuple[str, str] = ("name", "parent") # 'Module' can also be imported through an alias in flax.linen, but we always # want to use its full, unaliased name. _MODULE = "flax.linen.module" - def __init__(self, ctx, module): + def __init__(self, ctx, module) -> None: del module # unused pytd_cls = ctx.loader.lookup_pytd(self._MODULE, "Module") # flax.linen.Module loads as a LateType, we need to convert it and then get @@ -125,7 +128,7 @@ def __init__(self, ctx, module): pytd_cls = ctx.convert.constant_to_value(pytd_cls).pytd_cls super().__init__("Module", pytd_cls, ctx) - def init_subclass(self, node, cls): + def init_subclass(self, node: _T0, cls) -> _T0: # Subclasses of Module call self.setup() when creating instances. cls.additional_init_methods.append("setup") dc = ModuleDataclass.make(self.ctx) @@ -136,7 +139,7 @@ def init_subclass(self, node, cls): def to_pytd_type_of_instance( self, node=None, instance=None, seen=None, view=None - ): + ) -> pytd.NamedType: """Get the type an instance of us would have.""" # The class is imported as flax.linen.Module but aliases # flax.linen.module.Module internally @@ -148,5 +151,5 @@ def full_name(self): # overlay because we might want to overlay other things from flax.linen. 
return f"{self._MODULE}.{self.name}" - def __repr__(self): + def __repr__(self) -> str: return f"Overlay({self.full_name})" diff --git a/pytype/overlays/functools_overlay.py b/pytype/overlays/functools_overlay.py index eb2cc0430..822ef62ff 100644 --- a/pytype/overlays/functools_overlay.py +++ b/pytype/overlays/functools_overlay.py @@ -9,7 +9,7 @@ class FunctoolsOverlay(overlay.Overlay): """An overlay for the functools std lib module.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: member_map = { "cached_property": overlay.add_name( "cached_property", special_builtins.Property.make_alias diff --git a/pytype/overlays/future_overlay.py b/pytype/overlays/future_overlay.py index 1667ba417..86ba38afc 100644 --- a/pytype/overlays/future_overlay.py +++ b/pytype/overlays/future_overlay.py @@ -7,7 +7,7 @@ class FutureUtilsOverlay(overlay.Overlay): """A custom overlay for the 'future' module.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: member_map = { "with_metaclass": metaclass.WithMetaclass.make, } diff --git a/pytype/overlays/metaclass.py b/pytype/overlays/metaclass.py index bf4ffe0e5..a3d76f55c 100644 --- a/pytype/overlays/metaclass.py +++ b/pytype/overlays/metaclass.py @@ -5,24 +5,30 @@ # trigger them in inference. 
import logging +from typing import Any, TypeVar from pytype.abstract import abstract from pytype.abstract import abstract_utils from pytype.abstract import function from pytype.errors import error_types -log = logging.getLogger(__name__) + +_T0 = TypeVar("_T0") +_TAddMetaclass = TypeVar("_TAddMetaclass", bound="AddMetaclass") +_TWithMetaclass = TypeVar("_TWithMetaclass", bound="WithMetaclass") + +log: logging.Logger = logging.getLogger(__name__) class AddMetaclassInstance(abstract.BaseValue): """AddMetaclass instance (constructed by AddMetaclass.call()).""" - def __init__(self, meta, ctx, module): + def __init__(self, meta, ctx, module) -> None: super().__init__("AddMetaclassInstance", ctx) self.meta = meta self.module = module - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple[Any, Any]: del func, alias_map # unused if len(args.posargs) != 1: sig = function.Signature.from_param_names( @@ -46,10 +52,10 @@ class AddMetaclass(abstract.PyTDFunction): """Implements the add_metaclass decorator.""" @classmethod - def make(cls, ctx, module): + def make(cls: type[_TAddMetaclass], ctx, module) -> _TAddMetaclass: return super().make("add_metaclass", ctx, module) - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, alias_map=None) -> tuple[_T0, Any]: """Adds a metaclass.""" del func, alias_map # unused self.match_args(node, args) @@ -64,7 +70,7 @@ def call(self, node, func, args, alias_map=None): class WithMetaclassInstance(abstract.BaseValue, abstract.Class): # pytype: disable=signature-mismatch # overriding-return-type-checks """Anonymous class created by with_metaclass.""" - def __init__(self, ctx, cls, bases): + def __init__(self, ctx, cls, bases) -> None: super().__init__("WithMetaclassInstance", ctx) abstract.Class.init_mixin(self, cls) self.bases = bases @@ -86,10 +92,10 @@ class WithMetaclass(abstract.PyTDFunction): """Implements with_metaclass.""" @classmethod - def 
make(cls, ctx, module): + def make(cls: type[_TWithMetaclass], ctx, module) -> _TWithMetaclass: return super().make("with_metaclass", ctx, module) - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, alias_map=None) -> tuple[_T0, Any]: """Creates an anonymous class to act as a metaclass.""" del func, alias_map # unused self.match_args(node, args) diff --git a/pytype/overlays/named_tuple.py b/pytype/overlays/named_tuple.py index 60ebaf4b6..a371f1473 100644 --- a/pytype/overlays/named_tuple.py +++ b/pytype/overlays/named_tuple.py @@ -1,7 +1,7 @@ """Implementation of named tuples.""" import dataclasses -from typing import Any +from typing import TypeVar, Any from pytype import utils from pytype.abstract import abstract @@ -17,9 +17,20 @@ from pytype.pytd import pytd_utils from pytype.pytd import visitors +_T0 = TypeVar("_T0") +_TCollectionsNamedTupleBuilder = TypeVar( + "_TCollectionsNamedTupleBuilder", bound="CollectionsNamedTupleBuilder" +) +_TNamedTupleFuncBuilder = TypeVar( + "_TNamedTupleFuncBuilder", bound="NamedTupleFuncBuilder" +) +_TNamedTupleProperties = TypeVar( + "_TNamedTupleProperties", bound="NamedTupleProperties" +) + # type alias -Param = overlay_utils.Param +Param: type[overlay_utils.Param] = overlay_utils.Param # This module has classes and methods which benefit from extended docstrings, @@ -49,12 +60,14 @@ class NamedTupleProperties: bases: list[Any] @classmethod - def from_field_names(cls, name, field_names, ctx): + def from_field_names( + cls: type[_TNamedTupleProperties], name, field_names, ctx + ) -> _TNamedTupleProperties: """Make a NamedTupleProperties from field names with no types.""" fields = [Field(n, ctx.convert.unsolvable, None) for n in field_names] return cls(name, fields, []) - def validate_and_rename_fields(self, rename): + def validate_and_rename_fields(self, rename) -> None: """Validate and rename self.fields. 
namedtuple field names have some requirements: @@ -108,7 +121,7 @@ class _ArgsError(Exception): class _FieldMatchError(Exception): """Errors when postprocessing field args, to be converted to WrongArgTypes.""" - def __init__(self, param): + def __init__(self, param) -> None: super().__init__() self.param = param @@ -143,7 +156,7 @@ def extract_args(self, node, callargs): """ raise NotImplementedError() - def process_args(self, node, raw_args): + def process_args(self, node, raw_args) -> tuple[Any, Any]: """Convert namedtuple call args into a NamedTupleProperties. Returns both the NamedTupleProperties and an _Args struct in case the caller @@ -189,10 +202,12 @@ class CollectionsNamedTupleBuilder(_NamedTupleBuilderBase): """Factory for creating collections.namedtuple classes.""" @classmethod - def make(cls, ctx, module): + def make( + cls: type[_TCollectionsNamedTupleBuilder], ctx, module + ) -> _TCollectionsNamedTupleBuilder: return super().make("namedtuple", ctx, module) - def extract_args(self, node, callargs): + def extract_args(self, node, callargs) -> _Args: """Extracts the typename, field_names and rename arguments. collections.namedtuple takes a 'verbose' argument too but we don't care @@ -233,7 +248,7 @@ def extract_args(self, node, callargs): name=name, field_names=field_names, defaults=defaults, rename=rename ) - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, alias_map=None) -> tuple[Any, Any]: """Creates a namedtuple class definition.""" # If we can't extract the arguments, we take the easy way out and return Any try: @@ -251,7 +266,7 @@ class NamedTupleFuncBuilder(_NamedTupleBuilderBase): _fields_param: error_types.BadType @classmethod - def make(cls, ctx): + def make(cls: type[_TNamedTupleFuncBuilder], ctx) -> _TNamedTupleFuncBuilder: # typing.pytd contains a NamedTuple class def and a _NamedTuple func def. 
self = super().make("NamedTuple", ctx, "typing", pyval_name="_NamedTuple") # NamedTuple's fields arg has type Sequence[Sequence[Union[str, type]]], @@ -264,13 +279,13 @@ def make(cls, ctx): self._fields_param = error_types.BadType(name="fields", typ=fields_type) return self - def _is_str_instance(self, val): + def _is_str_instance(self, val) -> bool: return isinstance(val, abstract.Instance) and val.full_name in ( "builtins.str", "builtins.unicode", ) - def extract_args(self, node, callargs): + def extract_args(self, node, callargs) -> _Args: """Extracts the typename and fields arguments. fields is postprocessed into field_names and field_types. @@ -314,7 +329,7 @@ def extract_args(self, node, callargs): return _Args(name=cls_name, field_names=names, field_types=types) - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, alias_map=None) -> tuple[Any, Any]: try: args, props = self.process_args(node, args) except _ArgsError: @@ -335,7 +350,7 @@ class NamedTupleClassBuilder(abstract.PyTDClass): """Factory for creating typing.NamedTuples by subclassing NamedTuple.""" # attributes prohibited to set in NamedTuple class syntax - _prohibited = ( + _prohibited: tuple[str, str, str, str, str, str, str, str, str, str, str] = ( "__new__", "__init__", "__slots__", @@ -349,7 +364,7 @@ class NamedTupleClassBuilder(abstract.PyTDClass): "_source", ) - def __init__(self, ctx, module="typing"): + def __init__(self, ctx, module="typing") -> None: pyval = ctx.loader.lookup_pytd(module, "NamedTuple") super().__init__("NamedTuple", pyval, ctx) # Prior to python 3.6, NamedTuple is a function. 
Although NamedTuple is a @@ -384,7 +399,7 @@ def call(self, node, func, args, alias_map=None): ) return self.namedtuple.call(node, None, args, alias_map) - def make_class(self, node, bases, f_locals): + def make_class(self, node: _T0, bases, f_locals) -> tuple[Any, Any]: # If BuildClass.call() hits max depth, f_locals will be [unsolvable] # Since we don't support defining NamedTuple subclasses in a nested scope # anyway, we can just return unsolvable here to prevent a crash, and let the @@ -556,11 +571,11 @@ def make_class_from_pyi(self, cls_name, pytd_cls): class _DictBuilder: """Construct dict abstract classes for namedtuple members.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: self.ctx = ctx self.dict_cls = ctx.convert.lookup_value("builtins", "dict") - def make(self, typ): + def make(self, typ) -> abstract.ParameterizedClass: # Normally, we would use abstract_utils.K and abstract_utils.V, but # collections.pyi doesn't conform to that standard. return abstract.ParameterizedClass( @@ -571,7 +586,7 @@ def make(self, typ): class NamedTupleClass(abstract.InterpreterClass): """Named tuple classes.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) # Store the original properties, to output to pyi files. self.props = None @@ -593,7 +608,7 @@ def instantiate(self, node, container=None): return inst -def _build_namedtuple(props, node, ctx): +def _build_namedtuple(props, node, ctx) -> tuple[Any, Any]: """Build an InterpreterClass representing the namedtuple.""" # TODO(mdemello): Fix this to support late types. diff --git a/pytype/overlays/overlay.py b/pytype/overlays/overlay.py index 6e444321b..4fe99e104 100644 --- a/pytype/overlays/overlay.py +++ b/pytype/overlays/overlay.py @@ -31,7 +31,7 @@ def __init__(self, ctx) real_module: An abstract.Module wrapping the AST for the underlying module. 
""" - def __init__(self, ctx, name, member_map, ast): + def __init__(self, ctx, name, member_map, ast) -> None: """Initialize the overlay. Args: @@ -65,7 +65,7 @@ def get_module(self, name): else: return self.real_module - def items(self): + def items(self) -> list: items = super().items() items += [ (name, item) @@ -90,11 +90,11 @@ def maybe_load_member(self, member_name): return member -def add_name(name, builder): +def add_name(name, builder) -> Callable[[Any, Any], Any]: """Turns (name, ctx, module) -> val signatures into (ctx, module) -> val.""" return lambda ctx, module: builder(name, ctx, module) -def drop_module(builder): +def drop_module(builder) -> Callable[[Any, Any], Any]: """Turns (ctx) -> val signatures into (ctx, module) -> val.""" return lambda ctx, module: builder(ctx) diff --git a/pytype/overlays/overlay_dict.py b/pytype/overlays/overlay_dict.py index 1f11c721d..288ab8c1a 100644 --- a/pytype/overlays/overlay_dict.py +++ b/pytype/overlays/overlay_dict.py @@ -28,7 +28,7 @@ # Collection of module overlays, used by the vm to fetch an overlay # instead of the module itself. Memoized in the vm itself. -overlays = { +overlays: dict[str, type] = { "abc": abc_overlay.ABCOverlay, "asyncio": asyncio_types_overlay.AsyncioOverlay, "attr": attr_overlay.AttrOverlay, diff --git a/pytype/overlays/overlay_utils.py b/pytype/overlays/overlay_utils.py index c2ca54c4e..0b58c656c 100644 --- a/pytype/overlays/overlay_utils.py +++ b/pytype/overlays/overlay_utils.py @@ -1,17 +1,26 @@ """Utilities for writing overlays.""" +from typing import TypeVar from pytype.abstract import abstract from pytype.abstract import function from pytype.pytd import pep484 from pytype.pytd import pytd from pytype.typegraph import cfg +_TParam = TypeVar("_TParam", bound="Param") + # Various types accepted by the annotations dictionary. 
# Runtime type checking of annotations, since if we do have an unexpected type # being stored in annotations, we should catch that as soon as possible, and add # it to the list if valid. -PARAM_TYPES = ( +PARAM_TYPES: tuple[ + type[cfg.Variable], + type[abstract.Class], + type[abstract.TypeParameter], + type[abstract.Union], + type[abstract.Unsolvable], +] = ( cfg.Variable, abstract.Class, abstract.TypeParameter, @@ -23,26 +32,26 @@ class Param: """Internal representation of method parameters.""" - def __init__(self, name, typ=None, default=None): + def __init__(self, name, typ=None, default=None) -> None: if typ: assert isinstance(typ, PARAM_TYPES), (typ, type(typ)) self.name = name self.typ = typ self.default = default - def unsolvable(self, ctx, node): + def unsolvable(self: _TParam, ctx, node) -> _TParam: """Replace None values for typ and default with unsolvable.""" self.typ = self.typ or ctx.convert.unsolvable self.default = self.default or ctx.new_unsolvable(node) return self - def __repr__(self): + def __repr__(self) -> str: return f"Param({self.name}, {self.typ!r}, {self.default!r})" class TypingContainer(abstract.AnnotationContainer): - def __init__(self, name, ctx): + def __init__(self, name, ctx) -> None: if name in pep484.TYPING_TO_BUILTIN: module = "builtins" pytd_name = pep484.TYPING_TO_BUILTIN[name] @@ -153,7 +162,7 @@ def _process_annotation(param): return decorator.call(node, func=None, args=args)[1] -def add_base_class(node, cls, base_cls): +def add_base_class(node, cls, base_cls) -> None: """Inserts base_cls into the MRO of cls.""" # The class's MRO is constructed from its bases at the moment the class is # created, so both need to be updated. 
diff --git a/pytype/overlays/pytype_extensions_overlay.py b/pytype/overlays/pytype_extensions_overlay.py index a2ddd20bc..f401382db 100644 --- a/pytype/overlays/pytype_extensions_overlay.py +++ b/pytype/overlays/pytype_extensions_overlay.py @@ -7,7 +7,7 @@ class PytypeExtensionsOverlay(overlay.Overlay): """A custom overlay for the 'pytype_extensions' module.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: member_map = { "assert_type": overlay.add_name( "assert_type", special_builtins.AssertType.make_alias diff --git a/pytype/overlays/six_overlay.py b/pytype/overlays/six_overlay.py index 2af62d7ba..835634673 100644 --- a/pytype/overlays/six_overlay.py +++ b/pytype/overlays/six_overlay.py @@ -1,5 +1,7 @@ """Implementation of special members of third_party/six.""" +from collections.abc import Callable +from typing import Any from pytype.overlays import metaclass from pytype.overlays import overlay @@ -7,7 +9,7 @@ class SixOverlay(overlay.Overlay): """A custom overlay for the 'six' module.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: member_map = { "add_metaclass": metaclass.AddMetaclass.make, "with_metaclass": metaclass.WithMetaclass.make, @@ -20,7 +22,7 @@ def __init__(self, ctx): super().__init__(ctx, "six", member_map, ast) -def build_version_bool(major): +def build_version_bool(major) -> Callable[[Any, Any], Any]: def make(ctx, module): del module # unused return ctx.convert.bool_values[ctx.python_version[0] == major] diff --git a/pytype/overlays/special_builtins.py b/pytype/overlays/special_builtins.py index a041eacbf..07dd1dc01 100644 --- a/pytype/overlays/special_builtins.py +++ b/pytype/overlays/special_builtins.py @@ -1,6 +1,7 @@ """Custom implementations of builtin types.""" import contextlib +from typing import Any, TypeVar from pytype.abstract import abstract from pytype.abstract import abstract_utils @@ -9,11 +10,15 @@ from pytype.abstract import mixin from pytype.errors import error_types +_T0 = TypeVar("_T0") 
+_TBuiltinClass = TypeVar("_TBuiltinClass", bound="BuiltinClass") +_TBuiltinFunction = TypeVar("_TBuiltinFunction", bound="BuiltinFunction") + class TypeNew(abstract.PyTDFunction): """Implements type.__new__.""" - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple[Any, Any]: if len(args.posargs) == 4: self.match_args(node, args) # May raise FailedFunctionCall. cls, name_var, bases_var, class_dict_var = args.posargs @@ -73,15 +78,19 @@ class BuiltinFunction(abstract.PyTDFunction): _NAME: str = None @classmethod - def make(cls, ctx): + def make(cls: type[_TBuiltinFunction], ctx) -> _TBuiltinFunction: assert cls._NAME return super().make(cls._NAME, ctx, "builtins") @classmethod - def make_alias(cls, name, ctx, module): + def make_alias( + cls: type[_TBuiltinFunction], name, ctx, module + ) -> _TBuiltinFunction: return super().make(name, ctx, module) - def get_underlying_method(self, node, receiver, method_name): + def get_underlying_method( + self, node, receiver, method_name + ) -> tuple[Any, Any]: """Get the bound method that a built-in function delegates to.""" results = [] for b in receiver.bindings: @@ -109,7 +118,7 @@ class Abs(BuiltinFunction): _NAME = "abs" - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple[Any, Any]: self.match_args(node, args) arg = args.posargs[0] node, fn = self.get_underlying_method(node, arg, "__abs__") @@ -124,7 +133,7 @@ class Next(BuiltinFunction): _NAME = "next" - def _get_args(self, args): + def _get_args(self, args) -> tuple[Any, Any]: arg = args.posargs[0] if len(args.posargs) > 1: default = args.posargs[1] @@ -134,7 +143,7 @@ def _get_args(self, args): default = self.ctx.program.NewVariable() return arg, default - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple[Any, Any]: self.match_args(node, args) arg, default = self._get_args(args) node, fn = 
self.get_underlying_method(node, arg, "__next__") @@ -151,7 +160,7 @@ class Round(BuiltinFunction): _NAME = "round" - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple[Any, Any]: self.match_args(node, args) node, fn = self.get_underlying_method(node, args.posargs[0], "__round__") if fn is None: @@ -170,7 +179,7 @@ class ObjectPredicate(BuiltinFunction): def run(self, node, args, result): raise NotImplementedError(self.__class__.__name__) - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple[Any, Any]: try: self.match_args(node, args) node = self.ctx.connect_new_cfg_node(node, f"CallPredicate:{self.name}") @@ -193,7 +202,7 @@ class UnaryPredicate(ObjectPredicate): def _call_predicate(self, node, obj): raise NotImplementedError(self.__class__.__name__) - def run(self, node, args, result): + def run(self, node, args, result) -> None: for obj in args.posargs[0].bindings: node, pyval = self._call_predicate(node, obj) result.AddBinding( @@ -212,7 +221,7 @@ class BinaryPredicate(ObjectPredicate): def _call_predicate(self, node, left, right): raise NotImplementedError(self.__class__.__name__) - def run(self, node, args, result): + def run(self, node, args, result) -> None: for right in abstract_utils.expand_type_parameter_instances( args.posargs[1].bindings ): @@ -251,7 +260,7 @@ class HasAttr(BinaryPredicate): def _call_predicate(self, node, left, right): return self._has_attr(node, left.data, right.data) - def _has_attr(self, node, obj, attr): + def _has_attr(self, node, obj, attr) -> tuple[Any, bool | None]: """Check if the object has attribute attr. 
Args: @@ -279,10 +288,10 @@ class IsInstance(BinaryPredicate): _NAME = "isinstance" - def _call_predicate(self, node, left, right): + def _call_predicate(self, node: _T0, left, right) -> tuple[_T0, Any]: return node, self._is_instance(left.data, right.data) - def _is_instance(self, obj, class_spec): + def _is_instance(self, obj, class_spec) -> bool | None: """Check if the object matches a class specification. Args: @@ -307,10 +316,10 @@ class IsSubclass(BinaryPredicate): _NAME = "issubclass" - def _call_predicate(self, node, left, right): + def _call_predicate(self, node: _T0, left, right) -> tuple[_T0, Any]: return node, self._is_subclass(left.data, right.data) - def _is_subclass(self, cls, class_spec): + def _is_subclass(self, cls, class_spec) -> bool | None: """Check if the given class is a subclass of a class specification. Args: @@ -336,7 +345,7 @@ class IsCallable(UnaryPredicate): def _call_predicate(self, node, obj): return self._is_callable(node, obj) - def _is_callable(self, node, obj): + def _is_callable(self, node: _T0, obj) -> tuple[Any, bool | None]: """Check if the object is callable. Args: @@ -373,17 +382,19 @@ class BuiltinClass(abstract.PyTDClass): _NAME: str = None @classmethod - def make(cls, ctx): + def make(cls: type[_TBuiltinClass], ctx) -> _TBuiltinClass: assert cls._NAME return cls(cls._NAME, ctx, "builtins") @classmethod - def make_alias(cls, name, ctx, module): + def make_alias( + cls: type[_TBuiltinClass], name, ctx, module + ) -> _TBuiltinClass: # Although this method has the same signature as __init__, it makes alias # creation more readable and consistent with BuiltinFunction. 
return cls(name, ctx, module) - def __init__(self, name, ctx, module): + def __init__(self, name, ctx, module) -> None: super().__init__(name, ctx.loader.lookup_pytd(module, name), ctx) self.module = module @@ -391,7 +402,7 @@ def __init__(self, name, ctx, module): class SuperInstance(abstract.BaseValue): """The result of a super() call, i.e., a lookup proxy.""" - def __init__(self, cls, obj, ctx): + def __init__(self, cls, obj, ctx) -> None: super().__init__("super", ctx) self.cls = self.ctx.convert.super_type self.super_cls = cls @@ -426,7 +437,7 @@ def get_special_attribute(self, node, name, valself): else: return super().get_special_attribute(node, name, valself) - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, alias_map=None) -> tuple[_T0, Any]: self.ctx.errorlog.not_callable(self.ctx.vm.frames, self) return node, self.ctx.new_unsolvable(node) @@ -435,10 +446,12 @@ class Super(BuiltinClass): """The super() function. Calling it will create a SuperInstance.""" # Minimal signature, only used for constructing exceptions. - _SIGNATURE = function.Signature.from_param_names("super", ("cls", "self")) + _SIGNATURE: function.Signature = function.Signature.from_param_names( + "super", ("cls", "self") + ) _NAME = "super" - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, alias_map=None) -> tuple[_T0, Any]: result = self.ctx.program.NewVariable() num_args = len(args.posargs) if num_args == 0: @@ -502,7 +515,7 @@ class Object(BuiltinClass): _NAME = "object" - def is_object_new(self, func): + def is_object_new(self, func) -> bool: """Whether the given function is object.__new__. 
Args: @@ -564,7 +577,7 @@ class RevealType(BuiltinFunction): _NAME = "reveal_type" - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, alias_map=None) -> tuple[_T0, Any]: for a in args.posargs: self.ctx.errorlog.reveal_type(self.ctx.vm.frames, node, a) return node, self.ctx.convert.build_none(node) @@ -574,12 +587,12 @@ class AssertType(BuiltinFunction): """For debugging. assert_type(x, t) asserts that the type of "x" is "t".""" # Minimal signature, only used for constructing exceptions. - _SIGNATURE = function.Signature.from_param_names( + _SIGNATURE: function.Signature = function.Signature.from_param_names( "assert_type", ("variable", "type") ) _NAME = "assert_type" - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple[Any, Any]: if len(args.posargs) == 2: var, typ = args.posargs else: @@ -613,14 +626,14 @@ def call(self, node, func, args, alias_map=None): class Property(BuiltinClass): """Property decorator.""" - _KEYS = ["fget", "fset", "fdel", "doc"] + _KEYS: list[str] = ["fget", "fset", "fdel", "doc"] _NAME = "property" - def signature(self): + def signature(self) -> function.Signature: # Minimal signature, only used for constructing exceptions. 
return function.Signature.from_param_names(self.name, tuple(self._KEYS)) - def _get_args(self, args): + def _get_args(self, args) -> dict: ret = dict(zip(self._KEYS, args.posargs)) for k, v in args.namedargs.items(): if k not in self._KEYS: @@ -630,14 +643,14 @@ def _get_args(self, args): ret[k] = v return ret - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, alias_map=None) -> tuple[_T0, Any]: property_args = self._get_args(args) return node, PropertyInstance( self.ctx, self.name, self, **property_args ).to_variable(node) -def _is_fn_abstract(func_var): +def _is_fn_abstract(func_var) -> bool: if func_var is None: return False return any(getattr(d, "is_abstract", False) for d in func_var.data) @@ -646,7 +659,9 @@ def _is_fn_abstract(func_var): class PropertyInstance(abstract.Function, mixin.HasSlots): """Property instance (constructed by Property.call()).""" - def __init__(self, ctx, name, cls, fget=None, fset=None, fdel=None, doc=None): + def __init__( + self, ctx, name, cls, fget=None, fset=None, fdel=None, doc=None + ) -> None: super().__init__("property", ctx) mixin.HasSlots.init_mixin(self) self.name = name # Reports the correct decorator in error messages. 
@@ -672,7 +687,7 @@ def __init__(self, ctx, name, cls, fget=None, fset=None, fdel=None, doc=None): self.is_method = True self.bound_class = abstract.BoundFunction - def fget_slot(self, node, obj, objtype): + def fget_slot(self, node, obj, objtype) -> tuple[Any, Any]: obj_val = abstract_utils.get_atomic_value( obj, default=self.ctx.convert.unsolvable ) @@ -695,38 +710,38 @@ def fget_slot(self, node, obj, objtype): self.ctx, node, self.fget, function.Args((obj,)) ) - def fset_slot(self, node, obj, value): + def fset_slot(self, node, obj, value) -> tuple[Any, Any]: return function.call_function( self.ctx, node, self.fset, function.Args((obj, value)) ) - def fdelete_slot(self, node, obj): + def fdelete_slot(self, node, obj) -> tuple[Any, Any]: return function.call_function( self.ctx, node, self.fdel, function.Args((obj,)) ) - def getter_slot(self, node, fget): + def getter_slot(self, node: _T0, fget) -> tuple[_T0, Any]: prop = PropertyInstance( self.ctx, self.name, self.cls, fget, self.fset, self.fdel, self.doc ) result = self.ctx.program.NewVariable([prop], fget.bindings, node) return node, result - def setter_slot(self, node, fset): + def setter_slot(self, node: _T0, fset) -> tuple[_T0, Any]: prop = PropertyInstance( self.ctx, self.name, self.cls, self.fget, fset, self.fdel, self.doc ) result = self.ctx.program.NewVariable([prop], fset.bindings, node) return node, result - def deleter_slot(self, node, fdel): + def deleter_slot(self, node: _T0, fdel) -> tuple[_T0, Any]: prop = PropertyInstance( self.ctx, self.name, self.cls, self.fget, self.fset, fdel, self.doc ) result = self.ctx.program.NewVariable([prop], fdel.bindings, node) return node, result - def update_signature_scope(self, cls): + def update_signature_scope(self, cls) -> None: for fvar in (self.fget, self.fset, self.fdel): if fvar: for f in fvar.data: @@ -734,7 +749,7 @@ def update_signature_scope(self, cls): f.update_signature_scope(cls) -def _check_method_decorator_arg(fn_var, name, ctx): +def 
_check_method_decorator_arg(fn_var, name, ctx) -> bool: """Check that @classmethod or @staticmethod are applied to a function.""" for d in fn_var.data: try: @@ -750,7 +765,7 @@ def _check_method_decorator_arg(fn_var, name, ctx): class StaticMethodInstance(abstract.Function, mixin.HasSlots): """StaticMethod instance (constructed by StaticMethod.call()).""" - def __init__(self, ctx, cls, func): + def __init__(self, ctx, cls, func) -> None: super().__init__("staticmethod", ctx) mixin.HasSlots.init_mixin(self) self.func = func @@ -760,7 +775,7 @@ def __init__(self, ctx, cls, func): self.is_method = True self.bound_class = abstract.BoundFunction - def func_slot(self, node, obj, objtype): + def func_slot(self, node: _T0, obj, objtype) -> tuple[_T0, Any]: return node, self.func @@ -768,10 +783,12 @@ class StaticMethod(BuiltinClass): """Static method decorator.""" # Minimal signature, only used for constructing exceptions. - _SIGNATURE = function.Signature.from_param_names("staticmethod", ("func",)) + _SIGNATURE: function.Signature = function.Signature.from_param_names( + "staticmethod", ("func",) + ) _NAME = "staticmethod" - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, alias_map=None) -> tuple[_T0, Any]: if len(args.posargs) != 1: raise error_types.WrongArgCount(self._SIGNATURE, args, self.ctx) arg = args.posargs[0] @@ -787,7 +804,7 @@ class ClassMethodCallable(abstract.BoundFunction): class ClassMethodInstance(abstract.Function, mixin.HasSlots): """ClassMethod instance (constructed by ClassMethod.call()).""" - def __init__(self, ctx, cls, func): + def __init__(self, ctx, cls, func) -> None: super().__init__("classmethod", ctx) mixin.HasSlots.init_mixin(self) self.cls = cls @@ -797,11 +814,11 @@ def __init__(self, ctx, cls, func): self.is_method = True self.bound_class = ClassMethodCallable - def func_slot(self, node, obj, objtype): + def func_slot(self, node: _T0, obj, objtype) -> tuple[_T0, Any]: results = 
[ClassMethodCallable(objtype, b.data) for b in self.func.bindings] return node, self.ctx.program.NewVariable(results, [], node) - def update_signature_scope(self, cls): + def update_signature_scope(self, cls) -> None: for f in self.func.data: if isinstance(f, abstract.Function): f.update_signature_scope(cls) @@ -811,10 +828,12 @@ class ClassMethod(BuiltinClass): """Class method decorator.""" # Minimal signature, only used for constructing exceptions. - _SIGNATURE = function.Signature.from_param_names("classmethod", ("func",)) + _SIGNATURE: function.Signature = function.Signature.from_param_names( + "classmethod", ("func",) + ) _NAME = "classmethod" - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, alias_map=None) -> tuple[_T0, Any]: if len(args.posargs) != 1: raise error_types.WrongArgCount(self._SIGNATURE, args, self.ctx) arg = args.posargs[0] @@ -853,7 +872,7 @@ class Type(BuiltinClass, mixin.HasSlots): _NAME = "type" - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) mixin.HasSlots.init_mixin(self) slot = self.ctx.convert.convert_pytd_function( diff --git a/pytype/overlays/subprocess_overlay.py b/pytype/overlays/subprocess_overlay.py index c50e4c5f9..668f9b52c 100644 --- a/pytype/overlays/subprocess_overlay.py +++ b/pytype/overlays/subprocess_overlay.py @@ -8,7 +8,7 @@ class SubprocessOverlay(overlay.Overlay): """A custom overlay for the 'subprocess' module.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: member_map = { "Popen": Popen, } @@ -55,7 +55,7 @@ def _can_match_multiple(self, args): class Popen(abstract.PyTDClass, mixin.HasSlots): """Custom implementation of subprocess.Popen.""" - def __init__(self, ctx, module): + def __init__(self, ctx, module) -> None: pytd_cls = ctx.loader.lookup_pytd(module, "Popen") super().__init__("Popen", pytd_cls, ctx) mixin.HasSlots.init_mixin(self) diff --git a/pytype/overlays/sys_overlay.py 
b/pytype/overlays/sys_overlay.py index 201915fb7..40a0168b7 100644 --- a/pytype/overlays/sys_overlay.py +++ b/pytype/overlays/sys_overlay.py @@ -7,7 +7,7 @@ class SysOverlay(overlay.Overlay): """A custom overlay for the 'sys' module.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: member_map = { "platform": overlay.drop_module(build_platform), "version_info": overlay.drop_module(build_version_info), @@ -18,7 +18,13 @@ def __init__(self, ctx): class VersionInfo(abstract.Tuple): - ATTRIBUTES = ("major", "minor", "micro", "releaselevel", "serial") + ATTRIBUTES: tuple[str, str, str, str, str] = ( + "major", + "minor", + "micro", + "releaselevel", + "serial", + ) def get_special_attribute(self, node, name, valself): try: @@ -32,7 +38,7 @@ def build_platform(ctx): return ctx.convert.constant_to_value(ctx.options.platform) -def build_version_info(ctx): +def build_version_info(ctx) -> VersionInfo: """Build sys.version_info.""" version = [] # major, minor diff --git a/pytype/overlays/typed_dict.py b/pytype/overlays/typed_dict.py index 534d522d9..f3b219b0d 100644 --- a/pytype/overlays/typed_dict.py +++ b/pytype/overlays/typed_dict.py @@ -1,6 +1,7 @@ """Implementation of TypedDict.""" import dataclasses +from typing import Any, TypeVar from pytype.abstract import abstract from pytype.abstract import abstract_utils @@ -11,6 +12,11 @@ from pytype.pytd import pytd +_T0 = TypeVar("_T0") +_T1 = TypeVar("_T1") +_T2 = TypeVar("_T2") + + def _is_required(value: abstract.BaseValue) -> bool | None: name = value.full_name if name == "typing.Required": @@ -38,7 +44,7 @@ def keys(self): def optional(self): return self.keys - self.required - def add(self, k, v, total): + def add(self, k, v, total) -> None: """Adds key and value.""" req = _is_required(v) if req is None: @@ -52,7 +58,7 @@ def add(self, k, v, total): if required: self.required.add(k) - def check_keys(self, keys): + def check_keys(self, keys) -> tuple[Any, set]: keys = set(keys) missing = (self.keys - keys) & 
self.required extra = keys - self.keys @@ -62,7 +68,7 @@ def check_keys(self, keys): class TypedDictBuilder(abstract.PyTDClass): """Factory for creating typing.TypedDict classes.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: pyval = ctx.loader.lookup_pytd("typing", "TypedDict") super().__init__("TypedDict", pyval, ctx) # Signature for the functional constructor @@ -73,7 +79,7 @@ def __init__(self, ctx): self.ctx, "typing.TypedDict", sig ) - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, alias_map=None) -> tuple[_T0, Any]: """Call the functional constructor.""" props = self._extract_args(args) cls = TypedDictClass(props, self, self.ctx) @@ -88,7 +94,7 @@ def _extract_param(self, args, pos, name, pyval_type, typ): bad = error_types.BadType(name, typ) raise error_types.WrongArgTypes(self.fn_sig, args, self.ctx, bad) from e - def _extract_args(self, args): + def _extract_args(self, args) -> TypedDictProperties: if len(args.posargs) != 2: raise error_types.WrongArgCount(self.fn_sig, args, self.ctx) name = self._extract_param(args, 0, "name", str, self.ctx.convert.str_type) @@ -114,7 +120,7 @@ def _extract_args(self, args): props.add(k, value, total) return props - def _validate_bases(self, cls_name, bases): + def _validate_bases(self, cls_name, bases) -> None: """Check that all base classes are valid.""" for base_var in bases: for base in base_var.data: @@ -126,7 +132,7 @@ def _validate_bases(self, cls_name, bases): self.ctx.vm.frames, base_var, details ) - def _merge_base_class_fields(self, bases, props): + def _merge_base_class_fields(self, bases, props) -> None: """Add the merged list of base class fields to the fields dict.""" # Updates props in place, raises an error if a duplicate key is encountered. 
provenance = {k: props.name for k in props.fields} @@ -145,7 +151,7 @@ def _merge_base_class_fields(self, bases, props): props.add(k, v, base.props.total) provenance[k] = base.name - def make_class(self, node, bases, f_locals, total): + def make_class(self, node: _T0, bases, f_locals, total) -> tuple[_T0, Any]: # If BuildClass.call() hits max depth, f_locals will be [unsolvable] # See comment in NamedTupleClassBuilder.make_class(); equivalent logic # applies here. @@ -193,7 +199,7 @@ def make_class(self, node, bases, f_locals, total): cls_var = cls.to_variable(node) return node, cls_var - def make_class_from_pyi(self, cls_name, pytd_cls): + def make_class_from_pyi(self, cls_name, pytd_cls) -> "TypedDictClass": """Make a TypedDictClass from a pyi class.""" # NOTE: Returns the abstract class, not a variable. name = pytd_cls.name or cls_name @@ -223,16 +229,16 @@ def make_class_from_pyi(self, cls_name, pytd_cls): class TypedDictClass(abstract.PyTDClass): """A template for typed dicts.""" - def __init__(self, props, base_cls, ctx): + def __init__(self, props, base_cls, ctx) -> None: self.props = props self._base_cls = base_cls # TypedDictBuilder for constructing subclasses super().__init__(props.name, ctx.convert.dict_type.pytd_cls, ctx) self.init_method = self._make_init(props) - def __repr__(self): + def __repr__(self) -> str: return f"TypedDictClass({self.name})" - def _make_init(self, props): + def _make_init(self, props) -> abstract.SimpleFunction: # __init__ method for type checking signatures. # We construct this here and pass it to TypedDictClass because we need # access to abstract.SimpleFunction. 
@@ -247,7 +253,7 @@ def _make_init(self, props): } return abstract.SimpleFunction(sig, self.ctx) - def _new_instance(self, container, node, args): + def _new_instance(self, container, node, args) -> "TypedDict": self.init_method.match_and_map_args(node, args, None) ret = TypedDict(self.props, self.ctx) for k, v in args.namedargs.items(): @@ -278,7 +284,7 @@ class TypedDict(abstract.Dict): a regular dict. """ - def __init__(self, props, ctx): + def __init__(self, props, ctx) -> None: super().__init__(ctx) self.props = props self.set_native_slot("__delitem__", self.delitem_slot) @@ -292,15 +298,17 @@ def fields(self): def class_name(self): return self.props.name - def __repr__(self): + def __repr__(self) -> str: return f"" - def _check_str_key(self, name): + def _check_str_key(self, name: _T0) -> _T0: if name not in self.fields: raise error_types.TypedDictKeyMissing(self, name) return name - def _check_str_key_value(self, node, name, value_var): + def _check_str_key_value( + self, node, name: _T1, value_var: _T2 + ) -> tuple[_T1, _T2]: self._check_str_key(name) typ = self.fields[name] bad = self.ctx.matcher(node).compute_one_match(value_var, typ).bad_matches @@ -323,37 +331,37 @@ def _check_key(self, name_var): raise error_types.TypedDictKeyMissing(self, None) from e return self._check_str_key(name) - def _check_value(self, node, name_var, value_var): + def _check_value(self, node, name_var, value_var: _T2) -> _T2: """Check that value has the right type.""" # We have already called check_key so name is in fields name = abstract_utils.get_atomic_python_constant(name_var, str) self._check_str_key_value(node, name, value_var) return value_var - def getitem_slot(self, node, name_var): + def getitem_slot(self, node, name_var) -> tuple[Any, Any]: # A typed dict getitem should have a concrete string arg. If we have a var # with multiple bindings just fall back to Any. 
self._check_key(name_var) return super().getitem_slot(node, name_var) - def setitem_slot(self, node, name_var, value_var): + def setitem_slot(self, node, name_var, value_var) -> tuple[Any, Any]: self._check_key(name_var) self._check_value(node, name_var, value_var) return super().setitem_slot(node, name_var, value_var) - def set_str_item(self, node, name, value_var): + def set_str_item(self, node: _T0, name, value_var) -> _T0: self._check_str_key_value(node, name, value_var) return super().set_str_item(node, name, value_var) - def delitem_slot(self, node, name_var): + def delitem_slot(self, node, name_var) -> tuple[Any, Any]: self._check_key(name_var) return self.call_pytd(node, "__delitem__", name_var) - def pop_slot(self, node, key_var, default_var=None): + def pop_slot(self, node, key_var, default_var=None) -> tuple[Any, Any]: self._check_key(key_var) return super().pop_slot(node, key_var, default_var) - def get_slot(self, node, key_var, default_var=None): + def get_slot(self, node: _T0, key_var, default_var=None) -> tuple[_T0, Any]: try: str_key = self._check_key(key_var) except error_types.TypedDictKeyMissing: @@ -366,7 +374,7 @@ def get_slot(self, node, key_var, default_var=None): # here, or just `default | None`? 
return node, default_var or self.ctx.convert.none.to_variable(node) - def merge_instance_type_parameter(self, node, name, value): + def merge_instance_type_parameter(self, node, name, value) -> None: _, _, short_name = name.rpartition(".") if short_name == abstract_utils.K: expected_length = 1 @@ -382,14 +390,14 @@ def merge_instance_type_parameter(self, node, name, value): def _is_typeddict(val: abstract.BaseValue): if isinstance(val, abstract.Union): - return all(_is_typeddict(v) for v in val.options) + return all(_is_typeddict(v) for v in val.options) # pytype: disable=attribute-error return isinstance(val, TypedDictClass) class IsTypedDict(abstract.PyTDFunction): """Implementation of typing.is_typeddict.""" - def call(self, node, func, args, alias_map=None): + def call(self, node: _T0, func, args, alias_map=None) -> tuple[_T0, Any]: self.match_args(node, args) if args.posargs: tp = args.posargs[0] @@ -410,9 +418,15 @@ def call(self, node, func, args, alias_map=None): class _TypedDictItemRequiredness(overlay_utils.TypingContainer): """typing.(Not)Required.""" - _REQUIREDNESS = None + _REQUIREDNESS: None = None - def _get_value_info(self, inner, ellipses, allowed_ellipses=frozenset()): + def _get_value_info( + self, inner, ellipses, allowed_ellipses=frozenset() + ) -> tuple[ + tuple[int | str, ...], + tuple[abstract.BaseValue, ...], + type[abstract.ParameterizedClass], + ]: template, processed_inner, abstract_class = super()._get_value_info( inner, ellipses, allowed_ellipses ) diff --git a/pytype/overlays/typing_extensions_overlay.py b/pytype/overlays/typing_extensions_overlay.py index a4891ae36..37f4eee53 100644 --- a/pytype/overlays/typing_extensions_overlay.py +++ b/pytype/overlays/typing_extensions_overlay.py @@ -1,16 +1,17 @@ """Implementation of special members of typing_extensions.""" from pytype.overlays import typing_overlay +from pytype.typegraph import cfg class TypingExtensionsOverlay(typing_overlay.Redirect): """A custom overlay for the 
'typing_extensions' module.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: aliases = {"runtime": "typing.runtime_checkable"} super().__init__("typing_extensions", aliases, ctx) - def _convert_member(self, name, member, subst=None): + def _convert_member(self, name, member, subst=None) -> cfg.Variable: var = super()._convert_member(name, member, subst) for val in var.data: # typing_extensions backports typing features to older versions. diff --git a/pytype/overlays/typing_overlay.py b/pytype/overlays/typing_overlay.py index 7e725a1b7..d460da76f 100644 --- a/pytype/overlays/typing_overlay.py +++ b/pytype/overlays/typing_overlay.py @@ -4,11 +4,13 @@ # pylint: disable=unpacking-non-sequence import abc +from collections.abc import Callable, Sequence from typing import ( Dict as _Dict, Optional as _Optional, Tuple as _Tuple, Type as _Type, + Any ) from pytype import utils @@ -26,7 +28,7 @@ # type alias -Param = overlay_utils.Param +Param: type[overlay_utils.Param] = overlay_utils.Param def _is_typing_container(cls: pytd.Class): @@ -43,7 +45,7 @@ class TypingOverlay(overlay.Overlay): to import a typing member in a too-low runtime version. 
""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: # Make sure we have typing available as a dependency member_map = typing_overlay.copy() ast = ctx.loader.typing @@ -83,7 +85,7 @@ def _convert_member( class Redirect(overlay.Overlay): """Base class for overlays that redirect to typing.""" - def __init__(self, module, aliases, ctx): + def __init__(self, module, aliases, ctx) -> None: assert all(v.startswith("typing.") for v in aliases.values()) member_map = { k: _builder_from_name(v[len("typing.") :]) for k, v in aliases.items() @@ -105,7 +107,9 @@ def __init__(self, module, aliases, ctx): super().__init__(ctx, module, member_map, ast) -def _builder_from_name(name): +def _builder_from_name( + name, +) -> Callable[[Any, Any], Any]: def resolve(ctx, module): del module # unused pytd_val = ctx.loader.lookup_pytd("typing", name) @@ -117,7 +121,7 @@ def resolve(ctx, module): return resolve -def _builder(name, builder): +def _builder(name, builder) -> Callable[[Any, Any], Any]: """Turns (name, ctx) -> val signatures into (ctx, module) -> val.""" return lambda ctx, module: builder(name, ctx) @@ -125,10 +129,10 @@ def _builder(name, builder): class Union(abstract.AnnotationClass): """Implementation of typing.Union[...].""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: super().__init__("Union", ctx) - def _build_value(self, node, inner, ellipses): + def _build_value(self, node, inner, ellipses) -> abstract.Union: self.ctx.errorlog.invalid_ellipses(self.ctx.vm.frames, ellipses, self.name) return abstract.Union(inner, self.ctx) @@ -148,7 +152,7 @@ def _build_value(self, node, inner, ellipses): class Final(abstract.AnnotationClass): """Implementation of typing.Final[T].""" - def _build_value(self, node, inner, ellipses): + def _build_value(self, node, inner, ellipses) -> abstract.FinalAnnotation: self.ctx.errorlog.invalid_ellipses(self.ctx.vm.frames, ellipses, self.name) if len(inner) != 1: error = "typing.Final must wrap a single type" @@ -163,7 
+167,9 @@ def instantiate(self, node, container=None): class Tuple(overlay_utils.TypingContainer): """Implementation of typing.Tuple.""" - def _get_value_info(self, inner, ellipses, allowed_ellipses=frozenset()): + def _get_value_info( + self, inner, ellipses, allowed_ellipses=frozenset() + ): if ellipses: # An ellipsis may appear at the end of the parameter list as long as it is # not the only parameter. @@ -179,7 +185,7 @@ def _get_value_info(self, inner, ellipses, allowed_ellipses=frozenset()): class Callable(overlay_utils.TypingContainer): """Implementation of typing.Callable[...].""" - def getitem_slot(self, node, slice_var): + def getitem_slot(self, node, slice_var) -> tuple: content = abstract_utils.maybe_extract_tuple(slice_var) inner, ellipses = self._build_inner(content) args = inner[0] @@ -219,7 +225,9 @@ def getitem_slot(self, node, slice_var): value = self._build_value(node, tuple(inner), ellipses) return node, value.to_variable(node) - def _get_value_info(self, inner, ellipses, allowed_ellipses=frozenset()): + def _get_value_info( + self, inner, ellipses, allowed_ellipses=frozenset() + ): if isinstance(inner[0], list): template = list(range(len(inner[0]))) + [ t.name for t in self.base_cls.template @@ -242,7 +250,7 @@ def _get_value_info(self, inner, ellipses, allowed_ellipses=frozenset()): class TypeVarError(Exception): """Raised if an error is encountered while initializing a TypeVar.""" - def __init__(self, message, bad_call=None): + def __init__(self, message, bad_call=None) -> None: super().__init__(message) self.bad_call = bad_call @@ -288,7 +296,7 @@ def _get_typeparam_name(self, node, args): args.posargs[0], "name", str, arg_type_desc="a constant str" ) - def _get_typeparam_args(self, node, args): + def _get_typeparam_args(self, node, args) -> tuple[tuple, Any, Any, Any]: constraints = tuple( self._get_annotation(node, c, "constraint") for c in args.posargs[1:] ) @@ -308,7 +316,7 @@ def _get_typeparam_args(self, node, args): raise 
TypeVarError("ambiguous **kwargs not allowed") return constraints, bound, covariant, contravariant - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple: """Call typing.TypeVar().""" args = args.simplify(node, self.ctx) try: @@ -328,7 +336,7 @@ def call(self, node, func, args, alias_map=None): class TypeVar(_TypeVariable): """Representation of typing.TypeVar, as a function.""" - _ABSTRACT_CLASS = abstract.TypeParameter + _ABSTRACT_CLASS: type[abstract.TypeParameter] = abstract.TypeParameter @classmethod def make(cls, ctx, module): @@ -354,7 +362,7 @@ def _get_namedarg(self, node, args, name, default_value): class ParamSpec(_TypeVariable): """Representation of typing.ParamSpec, as a function.""" - _ABSTRACT_CLASS = abstract.ParamSpec + _ABSTRACT_CLASS: type[abstract.ParamSpec] = abstract.ParamSpec @classmethod def make(cls, ctx, module): @@ -375,7 +383,7 @@ def _get_namedarg(self, node, args, name, default_value): class Cast(abstract.PyTDFunction): """Implements typing.cast.""" - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple[Any, Any]: if args.posargs: _, value = self.ctx.annotation_utils.extract_and_init_annotation( node, "typing.cast", args.posargs[0] @@ -387,7 +395,7 @@ def call(self, node, func, args, alias_map=None): class Never(abstract.Singleton): """Implements typing.Never as a singleton.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: super().__init__("Never", ctx) # Sets cls to Type so that runtime usages of Never don't cause pytype to # think that Never is being used illegally in type annotations. 
@@ -397,7 +405,7 @@ def __init__(self, ctx): class NewType(abstract.PyTDFunction): """Implementation of typing.NewType as a function.""" - def __init__(self, name, signatures, kind, decorators, ctx): + def __init__(self, name, signatures, kind, decorators, ctx) -> None: super().__init__(name, signatures, kind, decorators, ctx) assert len(self.signatures) == 1, "NewType has more than one signature." signature = self.signatures[0].signature @@ -411,7 +419,7 @@ def internal_name_counter(self): self._internal_name_counter += 1 return val - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple[Any, Any]: args = args.simplify(node, self.ctx) self.match_args(node, args, match_all_views=True) # As long as the types match we do not really care about the actual @@ -454,7 +462,8 @@ def call(self, node, func, args, alias_map=None): class Overload(abstract.PyTDFunction): """Implementation of typing.overload.""" - def call(self, node, func, args, alias_map=None): + def call( + self, node, func, args, alias_map=None) -> tuple: """Marks that the given function is an overload.""" del func, alias_map # unused self.match_args(node, args) @@ -480,7 +489,7 @@ def make(cls, ctx, module): del module return super().make("final", ctx, "typing") - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple: """Marks that the given function is final.""" del func, alias_map # unused self.match_args(node, args) @@ -503,7 +512,9 @@ def _can_be_final(self, obj): class Generic(overlay_utils.TypingContainer): """Implementation of typing.Generic.""" - def _get_value_info(self, inner, ellipses, allowed_ellipses=frozenset()): + def _get_value_info( + self, inner, ellipses, allowed_ellipses=frozenset() + ) -> tuple[Sequence[str], Sequence, type[abstract.ParameterizedClass]]: template, inner = abstract_utils.build_generic_template(inner, self) return template, inner, abstract.ParameterizedClass @@ 
-511,7 +522,7 @@ def _get_value_info(self, inner, ellipses, allowed_ellipses=frozenset()): class Optional(abstract.AnnotationClass): """Implementation of typing.Optional.""" - def _build_value(self, node, inner, ellipses): + def _build_value(self, node, inner, ellipses) -> Union: self.ctx.errorlog.invalid_ellipses(self.ctx.vm.frames, ellipses, self.name) if len(inner) != 1: error = "typing.Optional can only contain one type parameter" @@ -561,7 +572,7 @@ def _build_value(self, node, inner, ellipses): class Concatenate(abstract.AnnotationClass): """Implementation of typing.Concatenate[...].""" - def _build_value(self, node, inner, ellipses): + def _build_value(self, node, inner, ellipses) -> abstract.Concatenate: self.ctx.errorlog.invalid_ellipses(self.ctx.vm.frames, ellipses, self.name) return abstract.Concatenate(list(inner), self.ctx) @@ -569,11 +580,11 @@ def _build_value(self, node, inner, ellipses): class ForwardRef(abstract.PyTDClass): """Implementation of typing.ForwardRef.""" - def __init__(self, ctx, module): + def __init__(self, ctx, module) -> None: pyval = ctx.loader.lookup_pytd(module, "ForwardRef") super().__init__("ForwardRef", pyval, ctx) - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple: # From https://docs.python.org/3/library/typing.html#typing.ForwardRef: # Class used for internal typing representation of string forward # references. [...] ForwardRef should not be instantiated by a user @@ -591,7 +602,7 @@ def call(self, node, func, args, alias_map=None): class DataclassTransformBuilder(abstract.PyTDFunction): """Minimal implementation of typing.dataclass_transform.""" - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple: del func, alias_map # unused # We are not yet doing anything with the args but since we have a type # signature available we might as well check it. 
@@ -607,10 +618,10 @@ def call(self, node, func, args, alias_map=None): class DataclassTransform(abstract.SimpleValue): """Minimal implementation of typing.dataclass_transform.""" - def __init__(self, ctx): + def __init__(self, ctx) -> None: super().__init__("", ctx) - def call(self, node, func, args, alias_map=None): + def call(self, node, func, args, alias_map=None) -> tuple: del func, alias_map # unused arg = args.posargs[0] for d in arg.data: @@ -650,7 +661,7 @@ def build_re_member(ctx, module): # name -> lowest_supported_version -_unsupported_members = { +_unsupported_members: dict[str, tuple[int, int]] = { "LiteralString": (3, 11), "TypeVarTuple": (3, 11), "Unpack": (3, 11), @@ -658,7 +669,7 @@ def build_re_member(ctx, module): # name -> (builder, lowest_supported_version) -typing_overlay = { +typing_overlay: dict[Any, tuple[Any, tuple[int, int] | None]] = { "Annotated": (_builder("Annotated", Annotated), (3, 9)), "Any": (overlay.drop_module(build_any), None), "Callable": (_builder("Callable", Callable), None), diff --git a/pytype/overriding_checks.py b/pytype/overriding_checks.py index e596f6c7f..3b22aa780 100644 --- a/pytype/overriding_checks.py +++ b/pytype/overriding_checks.py @@ -11,10 +11,10 @@ from pytype.abstract import function from pytype.pytd import pytd -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) # This should be context.Context, which can't be imported due to a circular dep. -_ContextType = Any +_ContextType: Any = Any _SignatureMapType = Mapping[str, function.Signature] @@ -53,7 +53,7 @@ def _get_varargs_annotation_type(param_type): def _check_positional_parameter_annotations( method_signature, base_signature, is_subtype -): +) -> SignatureError | None: """Checks type annotations for positional parameters of the overriding method. 
Args: @@ -213,7 +213,7 @@ def _check_positional_parameters( def _check_keyword_only_parameters( method_signature, base_signature, is_subtype -): +) -> SignatureError | None: """Checks that the keyword-only parameters of the overriding method match. Args: @@ -308,7 +308,9 @@ def _check_keyword_only_parameters( return None -def _check_default_values(method_signature, base_signature): +def _check_default_values( + method_signature, base_signature +) -> SignatureError | None: """Checks that default parameter values of the overriding method match. Args: @@ -363,7 +365,9 @@ def _check_default_values(method_signature, base_signature): return None -def _check_return_types(method_signature, base_signature, is_subtype): +def _check_return_types( + method_signature, base_signature, is_subtype +) -> SignatureError | None: """Checks that the return types match.""" try: base_return_type = base_signature.annotations["return"] @@ -389,7 +393,7 @@ def _check_return_types(method_signature, base_signature, is_subtype): def _check_signature_compatible( method_signature, base_signature, stack, matcher, ctx -): +) -> None: """Checks if the signatures match for the overridden and overriding methods. Adds the first error found to the context's error log. @@ -523,7 +527,7 @@ def _get_parameterized_class_signature_map( return method_signature_map -def check_overriding_members(cls, bases, members, matcher, ctx): +def check_overriding_members(cls, bases, members, matcher, ctx) -> None: """Check that the method signatures of the new class match base classes.""" # Maps method names to methods. 
diff --git a/pytype/pattern_matching.py b/pytype/pattern_matching.py index 65533821f..b3a6bb552 100644 --- a/pytype/pattern_matching.py +++ b/pytype/pattern_matching.py @@ -3,7 +3,7 @@ import collections import dataclasses import enum -from typing import Optional, Union, cast +from typing import Any, Generator, TypeVar, Optional, Union, cast from pytype.abstract import abstract from pytype.abstract import abstract_utils @@ -11,6 +11,8 @@ from pytype.pytd import slots from pytype.typegraph import cfg +_T_MatchTypes = TypeVar("_T_MatchTypes", bound="_MatchTypes") + # Type aliases @@ -85,7 +87,7 @@ def __init__(self, typ=None): def is_empty(self) -> bool: return not (self.values or self.indefinite) - def __repr__(self): + def __repr__(self) -> str: indef = "*" if self.indefinite else "" return f"" @@ -93,22 +95,22 @@ def __repr__(self): class _OptionSet: """Holds a set of options.""" - def __init__(self): + def __init__(self) -> None: # Collection of options, stored as a dict rather than a set so we can find a # given option efficiently. 
self._options: dict[abstract.Class, _Option] = {} - def __iter__(self): + def __iter__(self) -> Generator[None, Any, None]: yield from self._options.values() - def __bool__(self): + def __bool__(self) -> bool: return not self.is_complete @property def is_complete(self) -> bool: - return all(x.is_empty for x in self) + return all(x.is_empty for x in self) # pytype: disable=attribute-error - def add_instance(self, val): + def add_instance(self, val) -> None: """Add an instance to the match options.""" cls = val.cls if cls not in self._options: @@ -118,7 +120,7 @@ def add_instance(self, val): else: self.add_type(cls) - def add_type(self, cls): + def add_type(self, cls) -> None: """Add an class to the match options.""" if cls not in self._options: self._options[cls] = _Option(cls) @@ -169,7 +171,7 @@ def cover_type(self, val) -> list[_Value]: class _OptionTracker: """Tracks a set of match options.""" - def __init__(self, match_var, ctx): + def __init__(self, match_var, ctx) -> None: self.match_var: cfg.Variable = match_var self.ctx = ctx self.options: _OptionSet = _OptionSet() @@ -197,8 +199,8 @@ def get_narrowed_match_var(self, node) -> cfg.Variable: else: narrowed = [] for opt in self.options: - if not opt.is_empty: - narrowed.append(opt.typ.instantiate(node)) + if not opt.is_empty: # pytype: disable=attribute-error + narrowed.append(opt.typ.instantiate(node)) # pytype: disable=attribute-error return self.ctx.join_variables(node, narrowed) def cover(self, line, var) -> list[_Value]: @@ -241,7 +243,7 @@ def cover_from_none(self, line) -> list[_Value]: self.cases[line].add_type(cls) return self.options.cover_type(cls) - def invalidate(self): + def invalidate(self) -> None: self.is_valid = False @@ -265,7 +267,7 @@ def make(cls, op: opcodes.Opcode): class _Matches: """Tracks branches of match statements.""" - def __init__(self, ast_matches): + def __init__(self, ast_matches) -> None: self.start_to_end = {} # match_line : match_end_line self.end_to_starts = 
collections.defaultdict(list) self.match_cases = {} # opcode_line : match_line @@ -276,7 +278,7 @@ def __init__(self, ast_matches): for m in ast_matches.matches: self._add_match(m.start, m.end, m.cases) - def _add_match(self, start, end, cases): + def _add_match(self, start, end, cases) -> None: self.start_to_end[start] = end self.end_to_starts[end].append(start) self.unseen_cases[start] = {c.start for c in cases} @@ -288,11 +290,11 @@ def _add_match(self, start, end, cases): if c.as_name: self.as_names[c.end] = c.as_name - def register_case(self, match_line, case_line): + def register_case(self, match_line, case_line) -> None: assert self.match_cases[case_line] == match_line self.unseen_cases[match_line].discard(case_line) - def __repr__(self): + def __repr__(self) -> str: return f""" Matches: {sorted(self.start_to_end.items())} Cases: {self.match_cases} @@ -311,7 +313,7 @@ class IncompleteMatch: class BranchTracker: """Track exhaustiveness in pattern matches.""" - def __init__(self, ast_matches, ctx): + def __init__(self, ast_matches, ctx) -> None: self.matches = _Matches(ast_matches) self._option_tracker: dict[int, dict[int, _OptionTracker]] = ( collections.defaultdict(dict) @@ -362,7 +364,7 @@ def instantiate_case_var(self, op, match_var, node): tracker = self._get_option_tracker(match_var, match_line) if tracker.cases[op.line]: # We have matched on one or more classes in this case. 
- types = [x.typ for x in tracker.cases[op.line]] + types = [x.typ for x in tracker.cases[op.line]] # pytype: disable=attribute-error return self._make_instance_for_match(node, types) else: # We have not matched on a type, just bound the current match var to a @@ -508,7 +510,7 @@ def check_ending( for tracker in trackers.values(): if tracker.is_valid: for o in tracker.options: - if not o.is_empty and not o.indefinite: - ret.append(IncompleteMatch(start, o.values)) + if not o.is_empty and not o.indefinite: # pytype: disable=attribute-error + ret.append(IncompleteMatch(start, o.values)) # pytype: disable=attribute-error self._active_ends -= done return ret diff --git a/pytype/platform_utils/path_utils.py b/pytype/platform_utils/path_utils.py index 555d09c92..709208bf4 100644 --- a/pytype/platform_utils/path_utils.py +++ b/pytype/platform_utils/path_utils.py @@ -7,11 +7,16 @@ import glob as glob_module import os import sys +from typing import TypeVar + +AnyOrLiteralStr = TypeVar('AnyOrLiteralStr', str, bytes, str) +AnyStr = TypeVar('AnyStr', str, bytes) +_T0 = TypeVar('_T0') if sys.platform == 'win32': import ctypes # pylint: disable=g-import-not-at-top - def _short_path_to_long_path(path: str): + def _short_path_to_long_path(path: str) -> str: """Convert to long path names in win32.""" buffer = ctypes.create_unicode_buffer(0) required_size = ctypes.windll.kernel32.GetLongPathNameW(path, buffer, 0) @@ -27,11 +32,11 @@ def _short_path_to_long_path(path: str): else: - def _short_path_to_long_path(path: str): + def _short_path_to_long_path(path: str) -> str: return path -def _replace_driver_code(path: str): +def _replace_driver_code(path: str) -> str: drive, other = os.path.splitdrive(path) drive = drive.capitalize() return os.path.join(drive, other) @@ -39,18 +44,18 @@ def _replace_driver_code(path: str): if sys.platform == 'win32': - def standardize_return_path(path): + def standardize_return_path(path: _T0) -> _T0: path = _replace_driver_code(path) path = 
_short_path_to_long_path(path) return path else: - def standardize_return_path(path): + def standardize_return_path(path: _T0) -> _T0: return path -def _standardize_return_path_wrapper(func): +def _standardize_return_path_wrapper(func: _T0) -> _T0: """Standardize return path in win32.""" if sys.platform == 'win32': diff --git a/pytype/preprocess.py b/pytype/preprocess.py index fe4ca4410..2c8a64de1 100644 --- a/pytype/preprocess.py +++ b/pytype/preprocess.py @@ -1,6 +1,9 @@ """Preprocess source code before compilation.""" import ast +from typing import TypeVar + +_T0 = TypeVar("_T0") # pylint: disable=invalid-name @@ -11,11 +14,11 @@ def __init__(self): self.annotation_lines = [] self.in_function = False - def visit_AnnAssign(self, node): + def visit_AnnAssign(self, node) -> None: if self.in_function and node.value is None: self.annotation_lines.append(node.end_lineno - 1) # change to 0-based - def visit_FunctionDef(self, node): + def visit_FunctionDef(self, node) -> None: self.in_function = True for n in node.body: self.visit(n) @@ -25,7 +28,7 @@ def visit_FunctionDef(self, node): # pylint: enable=invalid-name -def augment_annotations(src): +def augment_annotations(src: _T0) -> _T0: """Add an assignment to bare variable annotations.""" try: tree = ast.parse(src) diff --git a/pytype/pyc/compile_bytecode.py b/pytype/pyc/compile_bytecode.py index 00136fb04..ed43edba2 100644 --- a/pytype/pyc/compile_bytecode.py +++ b/pytype/pyc/compile_bytecode.py @@ -7,17 +7,18 @@ import re import sys -MAGIC = importlib.util.MAGIC_NUMBER + +MAGIC: bytes = importlib.util.MAGIC_NUMBER # This pattern is as per PEP-263. 
ENCODING_PATTERN = "^[ \t\v]*#.*?coding[:=][ \t]*([-_.a-zA-Z0-9]+)" -def is_comment_only(line): +def is_comment_only(line) -> bool: return re.match("[ \t\v]*#.*", line) is not None -def _write32(f, w): +def _write32(f, w) -> None: f.write( bytearray( [(w >> 0) & 0xFF, (w >> 8) & 0xFF, (w >> 16) & 0xFF, (w >> 24) & 0xFF] @@ -25,7 +26,7 @@ def _write32(f, w): ) -def write_pyc(f, codeobject, source_size=0, timestamp=0): +def write_pyc(f, codeobject, source_size=0, timestamp=0) -> None: f.write(MAGIC) f.write(b"\r\n\0\0") _write32(f, timestamp) @@ -33,14 +34,14 @@ def write_pyc(f, codeobject, source_size=0, timestamp=0): f.write(marshal.dumps(codeobject)) -def compile_to_pyc(data_file, filename, output, mode): +def compile_to_pyc(data_file, filename, output, mode) -> None: """Compile the source code to byte code.""" with open(data_file, encoding="utf-8") as fi: src = fi.read() compile_src_to_pyc(src, filename, output, mode) -def strip_encoding(src): +def strip_encoding(src: str) -> str: """Strip encoding from a src string assumed to be read from a file.""" # Python 2's compile function does not like the line specifying the encoding. 
# So, we strip it off if it is present, replacing it with an empty comment to @@ -59,7 +60,7 @@ def strip_encoding(src): return src -def compile_src_to_pyc(src, filename, output, mode): +def compile_src_to_pyc(src, filename, output, mode) -> None: """Compile a string of source code.""" try: codeobject = compile(src, filename, mode) @@ -71,7 +72,7 @@ def compile_src_to_pyc(src, filename, output, mode): write_pyc(output, codeobject) -def main(): +def main() -> None: if len(sys.argv) != 4: sys.exit(1) output = sys.stdout.buffer if hasattr(sys.stdout, "buffer") else sys.stdout diff --git a/pytype/pyc/compiler.py b/pytype/pyc/compiler.py index 9bcb4a7c8..cef805b4e 100644 --- a/pytype/pyc/compiler.py +++ b/pytype/pyc/compiler.py @@ -22,10 +22,10 @@ # This would mean that when -V3.10 is passed to pytype, it will use the exe at # pytype/python3.10 to compile the code under analysis. Remember to add the new # file to the pytype_main_deps target! -_CUSTOM_PYTHON_EXES = {} +_CUSTOM_PYTHON_EXES: dict[None, None] = {} _COMPILE_SCRIPT = "pyc/compile_bytecode.py" -_COMPILE_ERROR_RE = re.compile(r"^(.*) \((.*), line (\d+)\)$") +_COMPILE_ERROR_RE: re.Pattern = re.compile(r"^(.*) \((.*), line (\d+)\)$") class PythonNotFoundError(Exception): @@ -35,7 +35,7 @@ class PythonNotFoundError(Exception): class CompileError(Exception): """A compilation error.""" - def __init__(self, msg): + def __init__(self, msg) -> None: super().__init__(msg) match = _COMPILE_ERROR_RE.match(msg) if match: @@ -50,7 +50,7 @@ def __init__(self, msg): def compile_src_string_to_pyc_string( src, filename, python_version, python_exe: list[str], mode="exec" -): +) -> bytes: """Compile Python source code to pyc data. 
This may use py_compile if the src is for the same version as we're running, @@ -182,7 +182,7 @@ def _get_python_exe_version(python_exe: list[str]): return _parse_exe_version_string(python_exe_version) -def _parse_exe_version_string(version_str): +def _parse_exe_version_string(version_str) -> tuple | None: """Parse the version string of a Python executable. Arguments: diff --git a/pytype/pyc/generate_opcode_diffs.py b/pytype/pyc/generate_opcode_diffs.py index f42d88afb..d0b1dd5c4 100644 --- a/pytype/pyc/generate_opcode_diffs.py +++ b/pytype/pyc/generate_opcode_diffs.py @@ -32,16 +32,21 @@ import sys import tempfile import textwrap +from typing import Any # Starting with Python 3.12, `dis` collections contain pseudo instructions and # instrumented instructions. These are opcodes with values >= MIN_PSEUDO_OPCODE # and >= MIN_INSTRUMENTED_OPCODE. # Pytype doesn't care about those, so we ignore them here. -_MIN_INSTRUMENTED_OPCODE = getattr(opcode, 'MIN_INSTRUMENTED_OPCODE', 237) +_MIN_INSTRUMENTED_OPCODE: Any = getattr(opcode, 'MIN_INSTRUMENTED_OPCODE', 237) -def generate_diffs(argv): +def generate_diffs( + argv, +) -> tuple[ + list[list[str]], list, list, list[str], list[str], Any, Any, list[str] +]: """Generate diffs.""" version1, version2 = argv @@ -223,7 +228,7 @@ def is_unset(opname_entry): ) -def _get_arg_type(dis, opname): +def _get_arg_type(dis, opname) -> str | None: all_types = ['CONST', 'NAME', 'JREL', 'JABS', 'LOCAL', 'FREE', 'NARGS'] for t in all_types: k = f'HAS_{t}' @@ -232,7 +237,7 @@ def _get_arg_type(dis, opname): return None -def _diff_intrinsic_descs(old, new, new_version): +def _diff_intrinsic_descs(old, new, new_version) -> tuple[list, list[list[str]]]: """Diff intrinsic descriptions and returns mapping and stubs if they differ.""" if old == new: return [], [] @@ -259,7 +264,7 @@ def _get_inline_cache_entries(dis, opname): ) -def main(argv): +def main(argv) -> None: ( classes, stubs, diff --git a/pytype/pyc/opcodes.py b/pytype/pyc/opcodes.py 
index 0907d1277..208fc50d1 100644 --- a/pytype/pyc/opcodes.py +++ b/pytype/pyc/opcodes.py @@ -1,12 +1,14 @@ """Opcode definitions.""" -from typing import cast +from typing import Any, TypeVar, cast import attrs from pycnite import bytecode import pycnite.types from typing_extensions import override +_TOpcode = TypeVar("_TOpcode", bound="Opcode") + # We define all-uppercase classes, to match their opcode names: # pylint: disable=invalid-name @@ -58,7 +60,7 @@ class Opcode: ) _FLAGS = 0 - def __init__(self, index, line, endline=None, col=None, endcol=None): + def __init__(self, index, line, endline=None, col=None, endcol=None) -> None: self.index = index self.line = line self.endline = endline @@ -75,7 +77,7 @@ def __init__(self, index, line, endline=None, col=None, endcol=None): self.push_exc_block = False self.pop_exc_block = False - def at_line(self, line): + def at_line(self: _TOpcode, line) -> _TOpcode: """Return a new opcode similar to this one but with a different line.""" # Ignore the optional slots (prev, next, block_target). 
op = Opcode(self.index, line) @@ -83,7 +85,7 @@ def at_line(self, line): op.code = self.code return op - def basic_str(self): + def basic_str(self) -> str: """Helper function for the various __str__ formats.""" folded = "<<<<" if self.folded else "" return "%d: %d: %s %s" % ( @@ -99,7 +101,7 @@ def __str__(self): else: return self.basic_str() - def __repr__(self): + def __repr__(self) -> str: return self.__class__.__name__ @property @@ -109,63 +111,63 @@ def name(self): @classmethod def for_python_version( cls, version: tuple[int, int] # pylint: disable=unused-argument - ): + ) -> type[_TOpcode]: return cls @classmethod - def has_const(cls): + def has_const(cls) -> bool: return bool(cls._FLAGS & HAS_CONST) @classmethod - def has_name(cls): + def has_name(cls) -> bool: return bool(cls._FLAGS & HAS_NAME) @classmethod - def has_jrel(cls): + def has_jrel(cls) -> bool: return bool(cls._FLAGS & HAS_JREL) @classmethod - def has_jabs(cls): + def has_jabs(cls) -> bool: return bool(cls._FLAGS & HAS_JABS) @classmethod - def has_known_jump(cls): + def has_known_jump(cls) -> bool: return bool(cls._FLAGS & (HAS_JREL | HAS_JABS)) @classmethod - def has_junknown(cls): + def has_junknown(cls) -> bool: return bool(cls._FLAGS & HAS_JUNKNOWN) @classmethod - def has_jump(cls): + def has_jump(cls) -> bool: return bool(cls._FLAGS & (HAS_JREL | HAS_JABS | HAS_JUNKNOWN)) @classmethod - def has_local(cls): + def has_local(cls) -> bool: return bool(cls._FLAGS & HAS_LOCAL) @classmethod - def has_free(cls): + def has_free(cls) -> bool: return bool(cls._FLAGS & HAS_FREE) @classmethod - def has_nargs(cls): + def has_nargs(cls) -> bool: return bool(cls._FLAGS & HAS_NARGS) @classmethod - def has_argument(cls): + def has_argument(cls) -> bool: return bool(cls._FLAGS & HAS_ARGUMENT) @classmethod - def no_next(cls): + def no_next(cls) -> bool: return bool(cls._FLAGS & NO_NEXT) @classmethod - def carry_on_to_next(cls): + def carry_on_to_next(cls) -> bool: return not cls._FLAGS & NO_NEXT @classmethod - 
def store_jump(cls): + def store_jump(cls) -> bool: return bool(cls._FLAGS & STORE_JUMP) @classmethod @@ -173,11 +175,11 @@ def does_jump(cls): return cls.has_jump() and not cls.store_jump() @classmethod - def pushes_block(cls): + def pushes_block(cls) -> bool: return bool(cls._FLAGS & PUSHES_BLOCK) @classmethod - def pops_block(cls): + def pops_block(cls) -> bool: return bool(cls._FLAGS & POPS_BLOCK) @@ -192,12 +194,12 @@ class OpcodeWithArg(Opcode): __slots__ = ("arg", "argval") - def __init__(self, index, line, endline, col, endcol, arg, argval): + def __init__(self, index, line, endline, col, endcol, arg, argval) -> None: super().__init__(index, line, endline, col, endcol) self.arg = arg self.argval = argval - def __str__(self): + def __str__(self) -> str: out = f"{self.basic_str()} {self.argval}" if self.annotation: return f"{out} # type: {self.annotation}" @@ -1145,7 +1147,7 @@ class LOAD_FROM_DICT_OR_DEREF(OpcodeWithArg): def _make_opcodes( ops: list[pycnite.types.Opcode], python_version: tuple[int, int] -): +) -> dict[int, Any]: """Convert pycnite opcodes to pytype opcodes.""" g = globals() offset_to_op = {} @@ -1159,7 +1161,7 @@ def _make_opcodes( return offset_to_op -def _add_exception_block(offset_to_op, e): +def _add_exception_block(offset_to_op, e) -> None: """Adds opcodes marking an exception block.""" start_op = offset_to_op[e.start] setup_op = SETUP_EXCEPT_311( @@ -1182,7 +1184,7 @@ def _add_exception_block(offset_to_op, e): offset_to_op[end + 0.5] = pop_op -def _get_exception_bitmask(offset_to_op, exception_ranges): +def _get_exception_bitmask(offset_to_op, exception_ranges) -> int: """Get a bitmask for whether an offset is in an exception range.""" in_exception = 0 pos = 1 @@ -1200,7 +1202,9 @@ def _get_exception_bitmask(offset_to_op, exception_ranges): # Opcodes that come up as exception targets but don't need a block. 
-_IGNORED_EXCEPTION_TARGETS = ( +_IGNORED_EXCEPTION_TARGETS: tuple[ + type[END_ASYNC_FOR], type[CLEANUP_THROW], type[SWAP] +] = ( # In 3.11+ `async for` loops end normally by thowing a StopAsyncIteration # exception, which jumps to an END_ASYNC_FOR opcode via the exception table. END_ASYNC_FOR, @@ -1216,7 +1220,7 @@ def _get_exception_bitmask(offset_to_op, exception_ranges): def _add_setup_except( offset_to_op: dict[float, Opcode], exc_table: pycnite.types.ExceptionTable -): +) -> None: """Handle the exception table in 3.11+.""" # In python 3.11, exception handling is no longer bytecode-based - see # https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt @@ -1269,7 +1273,7 @@ def _add_setup_except( def _get_opcode_following_cleanup_throw_jump_pairs( op_items: list[tuple[int, Opcode]], start_i: int -): +) -> Opcode | None: for i in range(start_i, len(op_items), 2): if ( isinstance(op_items[i][1], CLEANUP_THROW) @@ -1283,7 +1287,7 @@ def _get_opcode_following_cleanup_throw_jump_pairs( def _should_elide_opcode( op_items: list[tuple[int, Opcode]], i: int, python_version: tuple[int, int] -): +) -> bool: """Returns `True` if the opcode on index `i` should be elided. 
Opcodes should be elided if they don't contribute to type checking and cause @@ -1333,7 +1337,9 @@ def _should_elide_opcode( return False -def _make_opcode_list(offset_to_op, python_version: tuple[int, int]): +def _make_opcode_list( + offset_to_op, python_version: tuple[int, int] +) -> tuple[list, dict[Any, int]]: """Convert opcodes to a list and fill in opcode.index, next and prev.""" ops = [] offset_to_index = {} @@ -1359,7 +1365,7 @@ def _make_opcode_list(offset_to_op, python_version: tuple[int, int]): return ops, offset_to_index -def _add_jump_targets(ops, offset_to_index): +def _add_jump_targets(ops, offset_to_index) -> None: """Map the target of jump instructions to the opcode they jump to.""" for op in ops: op = cast(OpcodeWithArg, op) diff --git a/pytype/pyc/pyc.py b/pytype/pyc/pyc.py index f025d2f4b..98b928912 100644 --- a/pytype/pyc/pyc.py +++ b/pytype/pyc/pyc.py @@ -2,20 +2,23 @@ import abc import copy +from typing import TypeVar from pycnite import pyc from pytype import utils from pytype.pyc import compiler +_T0 = TypeVar("_T0") + # Reexport since we have exposed this error publicly as pyc.CompileError -CompileError = compiler.CompileError +CompileError: type[compiler.CompileError] = compiler.CompileError # The abstract base class for a code visitor passed to pyc.visit. class CodeVisitor(abc.ABC): - def __init__(self): + def __init__(self) -> None: # This cache, used by pyc.visit below, is needed to avoid visiting the same # code object twice, since some visitors mutate the input object. # It maps CodeType object id to the result of visiting that object. 
@@ -41,11 +44,11 @@ def parse_pyc_string(data): class AdjustFilename(CodeVisitor): """Visitor for changing co_filename in a code object.""" - def __init__(self, filename): + def __init__(self, filename) -> None: super().__init__() self.filename = filename - def visit_code(self, code): + def visit_code(self, code: _T0) -> _T0: code.co_filename = self.filename return code diff --git a/pytype/pyi/classdef.py b/pytype/pyi/classdef.py index d6f1595c5..7bbc67548 100644 --- a/pytype/pyi/classdef.py +++ b/pytype/pyi/classdef.py @@ -7,7 +7,7 @@ from pytype.pyi import types from pytype.pytd import pytd -_ParseError = types.ParseError +_ParseError: type[types.ParseError] = types.ParseError def get_bases( diff --git a/pytype/pyi/conditions.py b/pytype/pyi/conditions.py index c565da7b9..287dc6fb2 100644 --- a/pytype/pyi/conditions.py +++ b/pytype/pyi/conditions.py @@ -2,17 +2,19 @@ import ast as astlib +from typing import Any + from pytype.ast import visitor as ast_visitor from pytype.pyi import types from pytype.pytd import slots as cmp_slots -_ParseError = types.ParseError +_ParseError: type[types.ParseError] = types.ParseError class ConditionEvaluator(ast_visitor.BaseVisitor): """Evaluates if statements in pyi files.""" - def __init__(self, options): + def __init__(self, options) -> None: super().__init__(ast=astlib) self._compares = { astlib.Eq: cmp_slots.EQ, @@ -83,7 +85,7 @@ def fail(self, name=None): msg += "Supported checks are sys.platform and sys.version_info" raise _ParseError(msg) - def visit_Attribute(self, node): + def visit_Attribute(self, node) -> bool | str | None: if not isinstance(node.value, astlib.Name): self.fail() name = f"{node.value.id}.{node.attr}" @@ -94,7 +96,7 @@ def visit_Attribute(self, node): return bool(getattr(self._options, node.attr)) self.fail(name) - def visit_Slice(self, node): + def visit_Slice(self, node) -> slice: return slice(node.lower, node.upper, node.step) def visit_Index(self, node): @@ -103,13 +105,13 @@ def visit_Index(self, 
node): def visit_Pyval(self, node): return node.value - def visit_Subscript(self, node): + def visit_Subscript(self, node) -> tuple[Any, Any]: return (node.value, node.slice) - def visit_Tuple(self, node): + def visit_Tuple(self, node) -> tuple: return tuple(node.elts) - def visit_BoolOp(self, node): + def visit_BoolOp(self, node) -> bool: if isinstance(node.op, astlib.Or): return any(node.values) elif isinstance(node.op, astlib.And): @@ -123,7 +125,7 @@ def visit_UnaryOp(self, node): else: raise _ParseError(f"Unexpected unary operator: {node.op}") - def visit_Compare(self, node): + def visit_Compare(self, node) -> bool: if isinstance(node.left, tuple): ident = node.left else: @@ -137,7 +139,7 @@ def evaluate(test: astlib.AST, options) -> bool: return ConditionEvaluator(options).visit(test) -def _is_int_tuple(value): +def _is_int_tuple(value) -> bool: """Return whether the value is a tuple of integers.""" return isinstance(value, tuple) and all(isinstance(v, int) for v in value) diff --git a/pytype/pyi/definitions.py b/pytype/pyi/definitions.py index 0312ae740..66b3566e8 100644 --- a/pytype/pyi/definitions.py +++ b/pytype/pyi/definitions.py @@ -22,20 +22,24 @@ from pytype.pytd.parse import parser_constants # Typing members that represent sets of types. 
-_TYPING_SETS = ("typing.Intersection", "typing.Optional", "typing.Union") +_TYPING_SETS: tuple[str, str, str] = ( + "typing.Intersection", + "typing.Optional", + "typing.Union", +) -_ParseError = types.ParseError +_ParseError: type[types.ParseError] = types.ParseError _NodeT = TypeVar("_NodeT", bound=pytd.Node) class _DuplicateDefsError(Exception): - def __init__(self, duplicates): + def __init__(self, duplicates) -> None: super().__init__() self._duplicates = duplicates - def to_parse_error(self, namespace): + def to_parse_error(self, namespace) -> types.ParseError: return _ParseError( f"Duplicate attribute name(s) in {namespace}: " + ", ".join(self._duplicates) @@ -210,7 +214,7 @@ def _pytd_annotated(parameters: list[Any]) -> pytd.Type: class _InsertTypeParameters(visitors.Visitor): """Visitor for inserting TypeParameter instances.""" - def __init__(self, type_params): + def __init__(self, type_params) -> None: super().__init__() self.type_params = {p.name: p for p in type_params} @@ -224,14 +228,14 @@ def VisitNamedType(self, node): class _VerifyMutators(visitors.Visitor): """Visitor for verifying TypeParameters used in mutations are in scope.""" - def __init__(self): + def __init__(self) -> None: super().__init__() # A stack of type parameters introduced into the scope. The top of the stack # contains the currently accessible parameter set. 
self.type_params_in_scope = [set()] self.current_function = None - def _AddParams(self, params): + def _AddParams(self, params) -> None: top = self.type_params_in_scope[-1] self.type_params_in_scope.append(top | params) @@ -239,16 +243,16 @@ def _GetTypeParameters(self, node): params = pytd_utils.GetTypeParameters(node) return {x.name for x in params} - def EnterClass(self, node): + def EnterClass(self, node) -> None: params = set() for cls in node.bases: params |= self._GetTypeParameters(cls) self._AddParams(params) - def LeaveClass(self, _): + def LeaveClass(self, _) -> None: self.type_params_in_scope.pop() - def EnterFunction(self, node): + def EnterFunction(self, node) -> None: self.current_function = node params = set() for sig in node.signatures: @@ -260,11 +264,11 @@ def EnterFunction(self, node): params |= self._GetTypeParameters(sig.starstarargs.type) self._AddParams(params) - def LeaveFunction(self, _): + def LeaveFunction(self, _) -> None: self.type_params_in_scope.pop() self.current_function = None - def EnterParameter(self, node): + def EnterParameter(self, node) -> None: if isinstance(node.mutated_type, pytd.GenericType): params = self._GetTypeParameters(node.mutated_type) extra = params - self.type_params_in_scope[-1] @@ -279,17 +283,17 @@ def EnterParameter(self, node): class _ContainsAnyType(visitors.Visitor): """Check if a pytd object contains a type of any of the given names.""" - def __init__(self, type_names): + def __init__(self, type_names) -> None: super().__init__() self._type_names = set(type_names) self.found = False - def EnterNamedType(self, node): + def EnterNamedType(self, node) -> None: if node.name in self._type_names: self.found = True -def _contains_any_type(ast, type_names): +def _contains_any_type(ast, type_names) -> bool: """Convenience wrapper for _ContainsAnyType.""" out = _ContainsAnyType(type_names) ast.Visit(out) @@ -302,17 +306,17 @@ class _PropertyToConstant(visitors.Visitor): type_param_names: list[str] 
const_properties: list[list[pytd.Function]] - def EnterTypeDeclUnit(self, node): + def EnterTypeDeclUnit(self, node) -> None: self.type_param_names = [x.name for x in node.type_params] self.const_properties = [] - def LeaveTypeDeclUnit(self, node): + def LeaveTypeDeclUnit(self, node) -> None: self.type_param_names = None - def EnterClass(self, node): + def EnterClass(self, node) -> None: self.const_properties.append([]) - def LeaveClass(self, node): + def LeaveClass(self, node) -> None: self.const_properties.pop() def VisitClass(self, node): @@ -326,7 +330,7 @@ def VisitClass(self, node): methods = [x for x in node.methods if x not in self.const_properties[-1]] return node.Replace(constants=tuple(constants), methods=tuple(methods)) - def EnterFunction(self, node): + def EnterFunction(self, node) -> None: if ( self.const_properties and node.kind == pytd.MethodKind.PROPERTY @@ -334,7 +338,7 @@ def EnterFunction(self, node): ): self.const_properties[-1].append(node) - def _is_parametrised(self, method): + def _is_parametrised(self, method) -> bool | None: for sig in method.signatures: # 'method' is definitely parametrised if its return type contains a type # parameter defined in the current TypeDeclUnit. It's also likely @@ -351,9 +355,11 @@ def _is_parametrised(self, method): class Definitions: """Collect definitions used to build a TypeDeclUnit.""" - ELLIPSIS = types.Ellipsis() # Object to signal ELLIPSIS as a parameter. + ELLIPSIS: Ellipsis = ( + types.Ellipsis() + ) # Object to signal ELLIPSIS as a parameter. - def __init__(self, module_info): + def __init__(self, module_info) -> None: self.module_info = module_info self.type_map: dict[str, Any] = {} self.constants = [] @@ -364,7 +370,7 @@ def __init__(self, module_info): self.generated_classes = collections.defaultdict(list) self.module_path_map = {} - def add_alias_or_constant(self, alias_or_constant): + def add_alias_or_constant(self, alias_or_constant) -> None: """Add an alias or constant. 
Args: @@ -390,7 +396,9 @@ def new_type_from_value(self, value): else: return None - def new_alias_or_constant(self, name, value): + def new_alias_or_constant( + self, name, value + ) -> pytd.Alias | pytd.Constant: """Build an alias or constant.""" typ = self.new_type_from_value(value) if typ: @@ -398,7 +406,7 @@ def new_alias_or_constant(self, name, value): else: return pytd.Alias(name, value) - def new_new_type(self, name, typ): + def new_new_type(self, name, typ) -> pytd.NamedType: """Returns a type for a NewType.""" args = [("self", pytd.AnythingType()), ("val", typ)] ret = pytd.NamedType("NoneType") @@ -422,7 +430,7 @@ def new_new_type(self, name, typ): self.generated_classes[name].append(cls) return pytd.NamedType(cls_name) - def new_named_tuple(self, base_name, fields): + def new_named_tuple(self, base_name, fields) -> pytd.NamedType: """Return a type for a named tuple (implicitly generates a class). Args: @@ -437,7 +445,7 @@ def new_named_tuple(self, base_name, fields): self.add_import("typing", ["NamedTuple"]) return pytd.NamedType(nt.name) - def new_typed_dict(self, name, items, keywords): + def new_typed_dict(self, name, items, keywords) -> pytd.NamedType: """Returns a type for a TypedDict. This method is called only for TypedDict objects defined via the following @@ -482,7 +490,7 @@ def new_typed_dict(self, name, items, keywords): self.add_import("typing", ["TypedDict"]) return pytd.NamedType(cls_name) - def add_type_variable(self, name, tvar): + def add_type_variable(self, name, tvar) -> None: """Add a type variable definition.""" if tvar.kind == "TypeVar": pytd_type = pytd.TypeParameter @@ -506,7 +514,7 @@ def add_type_variable(self, name, tvar): ) ) - def add_import(self, from_package, import_list): + def add_import(self, from_package, import_list) -> None: """Add an import. 
Args: @@ -579,10 +587,10 @@ def _matches_named_type(self, t, names): return False return self.matches_type(t.name, names) - def _is_empty_tuple(self, t): + def _is_empty_tuple(self, t) -> bool: return isinstance(t, pytd.TupleType) and not t.parameters - def _is_heterogeneous_tuple(self, t): + def _is_heterogeneous_tuple(self, t) -> bool: return isinstance(t, pytd.TupleType) def _is_builtin_or_typing_member(self, t): @@ -593,7 +601,9 @@ def _is_builtin_or_typing_member(self, t): module == "typing" and name in pep484.ALL_TYPING_NAMES ) - def _check_for_illegal_parameters(self, base_type, parameters, is_callable): + def _check_for_illegal_parameters( + self, base_type, parameters, is_callable + ) -> None: if not self._is_builtin_or_typing_member(base_type): # TODO(b/217789659): We can only check builtin and typing names for now, # since `...` can fill in for a ParamSpec and `[]` can be used to @@ -609,7 +619,7 @@ def _check_for_illegal_parameters(self, base_type, parameters, is_callable): ): raise _ParseError("Unexpected list parameter") - def _remove_unsupported_features(self, parameters, is_callable): + def _remove_unsupported_features(self, parameters, is_callable) -> tuple: """Returns a copy of 'parameters' with unsupported features removed.""" processed_parameters = [] for p in parameters: @@ -841,7 +851,7 @@ def build_class( template=(), ) - def _adjust_self_var(self, fully_qualified_class_name, methods): + def _adjust_self_var(self, fully_qualified_class_name, methods) -> list: """Replaces typing.Self with a TypeVar.""" # TODO(b/224600845): Currently, this covers only Self used in a method # parameter or return annotation. 
@@ -958,7 +968,7 @@ def finalize_ast(ast: pytd.TypeDeclUnit): return ast -def _check_module_functions(functions): +def _check_module_functions(functions) -> None: """Validate top-level module functions.""" # module.__getattr__ should have a unique signature g = [f for f in functions if f.name == "__getattr__"] diff --git a/pytype/pyi/evaluator.py b/pytype/pyi/evaluator.py index 1f7391def..22ea697b7 100644 --- a/pytype/pyi/evaluator.py +++ b/pytype/pyi/evaluator.py @@ -12,7 +12,7 @@ from pytype.pyi import types -_NUM_TYPES = (int, float, complex) +_NUM_TYPES: tuple[type[int], type[float], type[complex]] = (int, float, complex) # pylint: disable=invalid-unary-operand-type diff --git a/pytype/pyi/function.py b/pytype/pyi/function.py index 98737fedc..9a726a279 100644 --- a/pytype/pyi/function.py +++ b/pytype/pyi/function.py @@ -11,7 +11,7 @@ from pytype.pytd.codegen import function as pytd_function from pytype.pytd.parse import parser_constants -_ParseError = types.ParseError +_ParseError: type[types.ParseError] = types.ParseError class Mutator(visitors.Visitor): @@ -26,7 +26,7 @@ def f(x: old_type): This visitor applies the body "x = new_type" to the function signature. 
""" - def __init__(self, name, new_type): + def __init__(self, name, new_type) -> None: super().__init__() self.name = name self.new_type = new_type @@ -43,7 +43,7 @@ def VisitParameter(self, p): else: return p - def __repr__(self): + def __repr__(self) -> str: return f"Mutator<{self.name} -> {self.new_type}>" __str__ = __repr__ @@ -191,7 +191,7 @@ def _pytd_star_param(arg: astlib.arg) -> pytd.Parameter | None: unpack = parser_constants.EXTERNAL_NAME_PREFIX + "typing_extensions.Unpack" if ( isinstance(arg.annotation, pytd.GenericType) - and arg.annotation.base_type.name == unpack + and arg.annotation.base_type.name == unpack # pytype: disable=attribute-error ): arg.annotation = pytd.AnythingType() return pytd_function.pytd_star_param(arg.arg, arg.annotation) # pytype: disable=wrong-arg-types diff --git a/pytype/pyi/metadata.py b/pytype/pyi/metadata.py index 5bfc46911..2568ff48a 100644 --- a/pytype/pyi/metadata.py +++ b/pytype/pyi/metadata.py @@ -15,10 +15,12 @@ """ import dataclasses -from typing import Any +from typing import Any, TypeVar from pytype.pyi import evaluator +_TCall = TypeVar("_TCall", bound="Call") + @dataclasses.dataclass class Call: @@ -29,21 +31,21 @@ class Call: kwargs: dict[str, Any] @classmethod - def from_metadata(cls, md, posarg_names, kwarg_names): + def from_metadata(cls: type[_TCall], md, posarg_names, kwarg_names) -> _TCall: fn = md["tag"] posargs = [md[k] for k in posarg_names if k in md] kwargs = {k: md[k] for k in kwarg_names if k in md} return cls(fn, posargs, kwargs) @classmethod - def from_call_dict(cls, md): + def from_call_dict(cls: type[_TCall], md) -> _TCall: assert md["tag"] == "call" fn = md["fn"] posargs = md["posargs"] or [] kwargs = md["kwargs"] or {} return cls(fn, posargs, kwargs) - def to_metadata(self, posarg_names, kwarg_names): + def to_metadata(self, posarg_names, kwarg_names: list[str]) -> dict: out = {"tag": self.fn} for name, arg in zip(posarg_names, self.posargs): out[name] = arg @@ -52,7 +54,7 @@ def 
to_metadata(self, posarg_names, kwarg_names): out[name] = self.kwargs[name] return out - def to_call_dict(self): + def to_call_dict(self) -> dict[str, list | str | dict[str, Any]]: return { "tag": "call", "fn": self.fn, @@ -60,7 +62,7 @@ def to_call_dict(self): "kwargs": self.kwargs, } - def to_pytd(self): + def to_pytd(self) -> str: posargs = ", ".join(map(repr, self.posargs)) kwargs = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items()) if posargs and kwargs: @@ -71,10 +73,12 @@ def to_pytd(self): # Convert some callables to their own specific metadata dicts. # {fn: (posarg_names, kwarg_names)} -_CALLABLES = {"Deprecated": (["reason"], [])} +_CALLABLES: dict[str, tuple[list[str], list[None]]] = { + "Deprecated": (["reason"], []) +} -def to_string(val: Any): +def to_string(val: Any) -> str: return repr(val) @@ -87,7 +91,7 @@ def call_to_annotation(fn, *, posargs=None, kwargs=None): call = Call(fn, posargs or (), kwargs or {}) if fn in _CALLABLES: posarg_names, kwarg_names = _CALLABLES[fn] - out = call.to_metadata(posarg_names, kwarg_names) + out = call.to_metadata(posarg_names, kwarg_names) # pytype: disable=attribute-error else: out = call.to_call_dict() return to_string(out) diff --git a/pytype/pyi/modules.py b/pytype/pyi/modules.py index ba75f9a8d..a5040b25a 100644 --- a/pytype/pyi/modules.py +++ b/pytype/pyi/modules.py @@ -9,7 +9,7 @@ from pytype.pytd import pytd from pytype.pytd.parse import parser_constants -_ParseError = types.ParseError +_ParseError: type[types.ParseError] = types.ParseError @dataclasses.dataclass @@ -21,21 +21,21 @@ class Import: new_name: str qualified_name: str = "" - def pytd_alias(self): + def pytd_alias(self) -> pytd.Alias: return pytd.Alias(self.new_name, self.pytd_node) class Module: """Module and package details.""" - def __init__(self, filename, module_name): + def __init__(self, filename, module_name) -> None: self.filename = filename self.module_name = module_name is_package = file_utils.is_pyi_directory_init(filename) 
self.package_name = module_utils.get_package_name(module_name, is_package) self.parent_name = module_utils.get_package_name(self.package_name, False) - def _qualify_name_with_special_dir(self, orig_name): + def _qualify_name_with_special_dir(self, orig_name) -> str | None: """Handle the case of '.' and '..' as package names.""" if "__PACKAGE__." in orig_name: # Generated from "from . import foo" - see parser.yy @@ -73,7 +73,7 @@ def qualify_name(self, orig_name): return name return orig_name - def process_import(self, item): + def process_import(self, item) -> Import | None: """Process 'import a, b as c, ...'.""" if isinstance(item, tuple): name, new_name = item @@ -88,7 +88,7 @@ def process_import(self, item): t = pytd.Module(name=as_name, module_name=module_name) return Import(pytd_node=t, name=name, new_name=new_name) - def process_from_import(self, from_package, item): + def process_from_import(self, from_package, item) -> Import: """Process 'from a.b.c import d, ...'.""" if isinstance(item, tuple): name, new_name = item diff --git a/pytype/pyi/parser.py b/pytype/pyi/parser.py index ebb6291f2..d69ad2869 100644 --- a/pytype/pyi/parser.py +++ b/pytype/pyi/parser.py @@ -8,6 +8,7 @@ import re import sys import tokenize +from types import EllipsisType from typing import Any, cast from pytype.ast import debug @@ -24,8 +25,9 @@ from pytype.pytd import visitors from pytype.pytd.codegen import decorate + # reexport as parser.ParseError -ParseError = types.ParseError +ParseError: type[types.ParseError] = types.ParseError # ------------------------------------------------------ # imports @@ -41,7 +43,7 @@ def _import_from_module(module: str | None, level: int) -> str: return prefix + module -def _keyword_to_parseable_name(kw): +def _keyword_to_parseable_name(kw) -> str: return f"__KW_{kw}__" @@ -102,14 +104,14 @@ def _attribute_to_name(node: astlib.Attribute) -> astlib.Name: elif isinstance(val, astlib.Attribute): prefix = _attribute_to_name(val).id elif isinstance(val, 
(pytd.NamedType, pytd.Module)): - prefix = val.name + prefix = val.name # pytype: disable=attribute-error else: msg = f"Unexpected attribute access on {val!r} [{type(val)}]" raise ParseError(msg) return astlib.Name(f"{prefix}.{node.attr}") -def _read_str_list(name, value): +def _read_str_list(name, value) -> tuple: if not ( isinstance(value, (list, tuple)) and all(types.Pyval.is_str(x) for x in value) @@ -121,10 +123,10 @@ def _read_str_list(name, value): class _ConvertConstantsVisitor(visitor.BaseVisitor): """Converts ast module constants to our own representation.""" - def __init__(self, filename): + def __init__(self, filename) -> None: super().__init__(filename=filename, visit_decorators=True) - def visit_Constant(self, node): + def visit_Constant(self, node) -> EllipsisType | types.Pyval: if node.value is Ellipsis: return definitions.Definitions.ELLIPSIS return types.Pyval.from_const(node) @@ -135,7 +137,7 @@ def visit_UnaryOp(self, node): return node.operand.negated() raise ParseError(f"Unexpected unary operator: {node.op}") - def visit_Assign(self, node): + def visit_Assign(self, node) -> None: if node.type_comment: # Convert the type comment from a raw string to a string constant. node.type_comment = types.Pyval( @@ -146,12 +148,12 @@ def visit_Assign(self, node): class _AnnotationVisitor(visitor.BaseVisitor): """Converts ast type annotations to pytd.""" - def __init__(self, filename, defs): + def __init__(self, filename, defs) -> None: super().__init__(filename=filename) self.defs = defs self.subscripted = [] # Keep track of the name being subscripted. 
- def show(self, node): + def show(self, node) -> None: print(debug.dump(node, astlib, include_attributes=False)) def _convert_late_annotation(self, annotation): @@ -168,7 +170,7 @@ def _convert_late_annotation(self, annotation): e.clear_position() raise e - def _in_literal(self): + def _in_literal(self) -> bool: if not self.subscripted: return False last = self.subscripted[-1] @@ -187,16 +189,16 @@ def visit_Pyval(self, node): else: raise ParseError(f"Unexpected literal: {node.value!r}") - def visit_Tuple(self, node): + def visit_Tuple(self, node) -> tuple: return tuple(node.elts) - def visit_List(self, node): + def visit_List(self, node) -> list: return list(node.elts) def visit_Name(self, node): return self.defs.new_type(node.id) - def _convert_getattr(self, node): + def _convert_getattr(self, node) -> pytd.NamedType | None: # The protobuf pyi generator outputs getattr(X, 'attr') when 'attr' is a # Python keyword. if node.func.name != "getattr" or len(node.args) != 2: @@ -217,16 +219,16 @@ def visit_Call(self, node): def _get_subscript_params(self, node): return node.slice - def _set_subscript_params(self, node, new_val): + def _set_subscript_params(self, node, new_val) -> None: node.slice = new_val - def _convert_typing_annotated_args(self, node): + def _convert_typing_annotated_args(self, node) -> None: typ, *args = self._get_subscript_params(node).elts typ = self.visit(typ) params = (_MetadataVisitor().visit(x) for x in args) self._set_subscript_params(node, (typ,) + tuple(params)) - def enter_Subscript(self, node): + def enter_Subscript(self, node) -> None: if isinstance(node.value, astlib.Attribute): value = _attribute_to_name(node.value) else: @@ -254,7 +256,7 @@ def visit_Subscript(self, node): params = (params,) return self.defs.new_type(node.value, params) - def leave_Subscript(self, node): + def leave_Subscript(self, node) -> None: self.subscripted.pop() def visit_Attribute(self, node): @@ -279,7 +281,7 @@ def visit_BoolOp(self, node): class 
_MetadataVisitor(visitor.BaseVisitor): """Converts typing.Annotated metadata.""" - def visit_Call(self, node): + def visit_Call(self, node) -> tuple[Any, tuple, Any]: posargs = tuple(evaluator.literal_eval(x) for x in node.args) kwargs = {x.arg: evaluator.literal_eval(x.value) for x in node.keywords} if isinstance(node.func, astlib.Attribute): @@ -310,20 +312,20 @@ def _flatten_splices(body: list[Any]) -> list[Any]: class Splice: """Splice a list into a node body.""" - def __init__(self, body): + def __init__(self, body) -> None: self.body = _flatten_splices(body) - def __str__(self): + def __str__(self) -> str: return "Splice(\n" + ",\n ".join([str(x) for x in self.body]) + "\n)" - def __repr__(self): + def __repr__(self) -> str: return str(self) class _GeneratePytdVisitor(visitor.BaseVisitor): """Converts an ast tree to a pytd tree.""" - _NOOP_NODES = { + _NOOP_NODES: set[type] = { # Expression contexts are ignored. astlib.Load, astlib.Store, @@ -337,14 +339,19 @@ class _GeneratePytdVisitor(visitor.BaseVisitor): types.Pyval, } - _ANNOT_NODES = ( + _ANNOT_NODES: tuple[ + type[astlib.Attribute], + type[astlib.BinOp], + type[astlib.Name], + type[astlib.Subscript], + ] = ( astlib.Attribute, astlib.BinOp, astlib.Name, astlib.Subscript, ) - def __init__(self, src, filename, module_name, options): + def __init__(self, src, filename, module_name, options) -> None: super().__init__(filename=filename, src_code=src, visit_decorators=True) defs = definitions.Definitions(modules.Module(filename, module_name)) self.defs = defs @@ -355,7 +362,7 @@ def __init__(self, src, filename, module_name, options): self.annotation_visitor = _AnnotationVisitor(filename=filename, defs=defs) self.class_stack = [] - def show(self, node): + def show(self, node) -> None: print(debug.dump(node, astlib, include_attributes=True)) def generic_visit(self, node): @@ -364,11 +371,11 @@ def generic_visit(self, node): return node raise NotImplementedError(f"Unsupported node type: {node_type.__name__}") - 
def visit_Module(self, node): + def visit_Module(self, node) -> pytd.TypeDeclUnit: node.body = _flatten_splices(node.body) return self.defs.build_type_decl_unit(node.body) - def visit_Pass(self, node): + def visit_Pass(self, node) -> EllipsisType: return self.defs.ELLIPSIS def visit_Expr(self, node): @@ -383,7 +390,9 @@ def visit_Expr(self, node): else: raise ParseError(f"Unexpected expression: {node.value}") - def _extract_function_properties(self, node): + def _extract_function_properties( + self, node + ) -> tuple[list, function.SigProperties]: decorators = [] abstract = coroutine = final = overload = False for d in node.decorator_list: @@ -412,7 +421,7 @@ def _extract_function_properties(self, node): is_async=isinstance(node, astlib.AsyncFunctionDef), ) - def visit_FunctionDef(self, node): + def visit_FunctionDef(self, node) -> function.NameAndSig: node.decorator_list, props = self._extract_function_properties(node) node.body = _flatten_splices(node.body) return function.NameAndSig.from_function(node, props) @@ -423,7 +432,7 @@ def visit_AsyncFunctionDef(self, node): def visit_AnnAssign(self, node): return self._ann_assign(node.target, node.annotation, node.value) - def _ann_assign(self, name, typ, val): + def _ann_assign(self, name, typ, val) -> pytd.Alias | pytd.Constant: is_alias = False if name == "__match_args__" and isinstance(val, tuple): typ = pytd.NamedType("tuple") @@ -485,7 +494,7 @@ def type_of(n): self.defs.add_alias_or_constant(ret) return ret - def visit_AugAssign(self, node): + def visit_AugAssign(self, node) -> Splice: if node.target == "__all__": # Ignore other assignments self.defs.all += _read_str_list(node.target, node.value) @@ -525,7 +534,7 @@ def _bare_assign(self, name, typ, val): self.defs.add_alias_or_constant(ret) return ret - def visit_Assign(self, node): + def visit_Assign(self, node) -> Splice: out = [] value = node.value for target in node.targets: @@ -540,7 +549,7 @@ def visit_Assign(self, node): 
out.append(self._bare_assign(target, node.type_comment, value)) return Splice(out) - def visit_ClassDef(self, node): + def visit_ClassDef(self, node) -> pytd.Class: full_class_name = ".".join(self.class_stack) self.defs.type_map[full_class_name] = pytd.NamedType(full_class_name) defs = _flatten_splices(node.body) @@ -548,7 +557,7 @@ def visit_ClassDef(self, node): full_class_name, node.bases, node.keywords, node.decorator_list, defs ) - def enter_If(self, node): + def enter_If(self, node) -> None: # Evaluate the test and preemptively remove the invalid branch so we don't # waste time traversing it. node.test = conditions.evaluate(node.test, self.options) @@ -560,7 +569,7 @@ def enter_If(self, node): else: node.body = [] - def visit_If(self, node): + def visit_If(self, node) -> Splice: if not isinstance(node.test, bool): raise ParseError("Unexpected if statement " + debug.dump(node, astlib)) @@ -569,13 +578,13 @@ def visit_If(self, node): else: return Splice(node.orelse) - def visit_Import(self, node): + def visit_Import(self, node) -> Splice: if self.level > 0: raise ParseError("Import statements need to be at module level") self.defs.add_import(None, node.names) return Splice([]) - def visit_ImportFrom(self, node): + def visit_ImportFrom(self, node) -> Splice: if self.level > 0: raise ParseError("Import statements need to be at module level") module = _import_from_module(node.module, node.level) @@ -590,16 +599,16 @@ def visit_alias(self, node): def visit_Name(self, node): return _parseable_name_to_real_name(node.id) - def visit_Attribute(self, node): + def visit_Attribute(self, node) -> str: return f"{node.value}.{node.attr}" - def visit_Tuple(self, node): + def visit_Tuple(self, node) -> tuple: return tuple(node.elts) - def visit_List(self, node): + def visit_List(self, node) -> list: return list(node.elts) - def visit_Dict(self, node): + def visit_Dict(self, node) -> dict: return dict(zip(node.keys, node.values)) def visit_Call(self, node): @@ -647,14 +656,14 
@@ def visit_Call(self, node): # List = _Alias() return node.func - def visit_Raise(self, node): + def visit_Raise(self, node) -> types.Raise: return types.Raise(node.exc) # We convert type comments and annotations in enter() because we want to # convert an entire type at once rather than bottom-up. enter() and leave() # are also used to track nesting level. - def _convert_value(self, node): + def _convert_value(self, node) -> None: if isinstance(node.value, self._ANNOT_NODES): node.value = self.annotation_visitor.visit(node.value) elif isinstance(node.value, (astlib.Tuple, astlib.List)): @@ -666,21 +675,21 @@ def _convert_value(self, node): ] node.value = type(node.value)(elts) - def enter_Assign(self, node): + def enter_Assign(self, node) -> None: if node.type_comment: node.type_comment = self.annotation_visitor.visit(node.type_comment) self._convert_value(node) - def enter_AnnAssign(self, node): + def enter_AnnAssign(self, node) -> None: if node.annotation: node.annotation = self.annotation_visitor.visit(node.annotation) self._convert_value(node) - def enter_arg(self, node): + def enter_arg(self, node) -> None: if node.annotation: node.annotation = self.annotation_visitor.visit(node.annotation) - def _convert_list(self, lst, start=0): + def _convert_list(self, lst, start=0) -> None: lst[start:] = [self.annotation_visitor.visit(x) for x in lst[start:]] def _convert_newtype_args(self, node: astlib.Call): @@ -717,11 +726,11 @@ def enter_Call(self, node): elif self.defs.matches_type(func, "typing.NewType"): return self._convert_newtype_args(node) - def enter_Raise(self, node): + def enter_Raise(self, node) -> None: exc = node.exc.func if isinstance(node.exc, astlib.Call) else node.exc node.exc = self.annotation_visitor.visit(exc) - def _convert_decorators(self, node): + def _convert_decorators(self, node) -> None: decorators = [] for d in node.decorator_list: base = d.func if isinstance(d, astlib.Call) else d @@ -734,24 +743,24 @@ def _convert_decorators(self, 
node): decorators.append(pytd.Alias(name.id, typ)) node.decorator_list = decorators - def enter_FunctionDef(self, node): + def enter_FunctionDef(self, node) -> None: self._convert_decorators(node) if node.returns: node.returns = self.annotation_visitor.visit(node.returns) self.level += 1 self.in_function = True - def leave_FunctionDef(self, node): + def leave_FunctionDef(self, node) -> None: self.level -= 1 self.in_function = False - def enter_AsyncFunctionDef(self, node): + def enter_AsyncFunctionDef(self, node) -> None: self.enter_FunctionDef(node) - def leave_AsyncFunctionDef(self, node): + def leave_AsyncFunctionDef(self, node) -> None: self.leave_FunctionDef(node) - def enter_ClassDef(self, node): + def enter_ClassDef(self, node) -> None: self._convert_decorators(node) node.bases = [ self.annotation_visitor.visit(base) @@ -765,7 +774,7 @@ def enter_ClassDef(self, node): self.level += 1 self.class_stack.append(_parseable_name_to_real_name(node.name)) - def leave_ClassDef(self, node): + def leave_ClassDef(self, node) -> None: self.level -= 1 self.class_stack.pop() @@ -836,7 +845,7 @@ def _is_varname(i): return "\n".join(lines) -def _parse(src: str, feature_version: int, filename: str = ""): +def _parse(src: str, feature_version: int, filename: str = "") -> astlib.Module: """Call the ast parser with the appropriate feature version.""" kwargs = {"feature_version": feature_version, "type_comments": True} try: @@ -865,7 +874,7 @@ def _feature_version(python_version: tuple[int, ...]) -> int: # Options that will be copied from pytype.config.Options. 
-_TOPLEVEL_PYI_OPTIONS = ( +_TOPLEVEL_PYI_OPTIONS: tuple[str, str, str] = ( "platform", "python_version", "strict_primitive_comparisons", diff --git a/pytype/pyi/parser_test_base.py b/pytype/pyi/parser_test_base.py index 6d137caee..fec92fda5 100644 --- a/pytype/pyi/parser_test_base.py +++ b/pytype/pyi/parser_test_base.py @@ -3,21 +3,27 @@ import re import textwrap +from typing import Any + from pytype.pyi import parser +from pytype.pytd import pytd from pytype.pytd import pytd_utils from pytype.tests import test_base -IGNORE = object() + +IGNORE: Any = object() class ParserTestBase(test_base.UnitTest): """Base class for pyi parsing tests.""" - def setUp(self): + def setUp(self) -> None: super().setUp() self.options = parser.PyiOptions(python_version=self.python_version) - def parse(self, src, name=None, version=None, platform="linux"): + def parse( + self, src, name=None, version=None, platform="linux" + ) -> pytd.TypeDeclUnit: if version: self.options.python_version = version self.options.platform = platform @@ -67,7 +73,7 @@ def check( self.assertMultiLineEqual(expected.rstrip(), actual) return ast - def check_error(self, src, expected_line, message): + def check_error(self, src, expected_line, message) -> None: """Check that parsing the src raises the expected error.""" with self.assertRaises(parser.ParseError) as e: parser.parse_string(textwrap.dedent(src).lstrip(), options=self.options) diff --git a/pytype/pyi/types.py b/pytype/pyi/types.py index a92707ffe..661513f43 100644 --- a/pytype/pyi/types.py +++ b/pytype/pyi/types.py @@ -2,14 +2,17 @@ import ast as astlib import dataclasses -from typing import Any +from typing import Any, TypeVar from pytype.pytd import pytd -_STRING_TYPES = ("str", "bytes", "unicode") +_TParseError = TypeVar("_TParseError", bound="ParseError") +_TPyval = TypeVar("_TPyval", bound="Pyval") +_STRING_TYPES: tuple[str, str, str] = ("str", "bytes", "unicode") -def node_position(node): + +def node_position(node) -> tuple[Any, Any]: # NOTE: 
ast.Module has no position info, and will be the `node` when # build_type_decl_unit() is called, so we cannot call `node.lineno` return getattr(node, "lineno", None), getattr(node, "col_offset", None) @@ -18,7 +21,9 @@ def node_position(node): class ParseError(Exception): """Exceptions raised by the parser.""" - def __init__(self, msg, line=None, filename=None, column=None, text=None): + def __init__( + self, msg, line=None, filename=None, column=None, text=None + ) -> None: super().__init__(msg) self._line = line self._filename = filename @@ -34,7 +39,9 @@ def from_exc(cls, exc) -> "ParseError": else: return cls(repr(exc)) - def at(self, node, filename=None, src_code=None): + def at( + self: _TParseError, node, filename=None, src_code=None + ) -> _TParseError: """Add position information from `node` if it doesn't already exist.""" if not self._line: self._line, self._column = node_position(node) @@ -47,14 +54,14 @@ def at(self, node, filename=None, src_code=None): pass return self - def clear_position(self): + def clear_position(self) -> None: self._line = None @property def line(self): return self._line - def __str__(self): + def __str__(self) -> str: lines = [] if self._filename or self._line is not None: lines.append(f' File: "{self._filename}", line {self._line}') @@ -98,10 +105,10 @@ class Pyval(astlib.AST): def from_const(cls, node: astlib.Constant): return cls(type(node.value).__name__, node.value, *node_position(node)) - def to_pytd(self): + def to_pytd(self) -> pytd.NamedType: return pytd.NamedType(self.type) - def repr_str(self): + def repr_str(self) -> str: """String representation with prefixes.""" if self.type == "unicode": val = f"u{self.value!r}" @@ -109,7 +116,7 @@ def repr_str(self): val = repr(self.value) return val - def to_pytd_literal(self): + def to_pytd_literal(self) -> pytd.Literal | pytd.NamedType: """Make a pytd node from Literal[self.value].""" if self.type == "NoneType": return pytd.NamedType("NoneType") @@ -121,17 +128,17 @@ def 
to_pytd_literal(self): val = self.value return pytd.Literal(val) - def negated(self): + def negated(self: _TPyval) -> _TPyval: """Return a new constant with value -self.value.""" if self.type in ("int", "float"): return Pyval(self.type, -self.value, self.lineno, self.col_offset) raise ParseError("Unary `-` can only apply to numeric literals.") @classmethod - def is_str(cls, value): + def is_str(cls, value) -> bool: return isinstance(value, cls) and value.type in _STRING_TYPES - def __repr__(self): + def __repr__(self) -> str: return f"LITERAL({self.repr_str()})" diff --git a/pytype/pyi/visitor.py b/pytype/pyi/visitor.py index 97de139cd..6f3fea28e 100644 --- a/pytype/pyi/visitor.py +++ b/pytype/pyi/visitor.py @@ -1,11 +1,14 @@ """Base visitor for ast parse trees.""" import ast as astlib +from typing import TypeVar from pytype.ast import visitor as ast_visitor from pytype.pyi import types -_ParseError = types.ParseError +_T0 = TypeVar('_T0') + +_ParseError: type[types.ParseError] = types.ParseError class BaseVisitor(ast_visitor.BaseVisitor): @@ -16,7 +19,9 @@ class BaseVisitor(ast_visitor.BaseVisitor): - Has an optional Definitions member """ - def __init__(self, *, filename=None, src_code=None, visit_decorators=False): + def __init__( + self, *, filename=None, src_code=None, visit_decorators=False + ) -> None: super().__init__(astlib, visit_decorators=visit_decorators) self.filename = filename # used for error messages self.src_code = src_code # used for error messages @@ -33,11 +38,11 @@ def visit(self, node): except Exception as e: # pylint: disable=broad-except raise _ParseError.from_exc(e).at(node, self.filename, self.src_code) - def leave(self, node): + def leave(self, node) -> None: try: return super().leave(node) except Exception as e: # pylint: disable=broad-except raise _ParseError.from_exc(e).at(node, self.filename, self.src_code) - def generic_visit(self, node): + def generic_visit(self, node: _T0) -> _T0: return node diff --git 
a/pytype/pytd/abc_hierarchy.py b/pytype/pytd/abc_hierarchy.py index 9ebf02b8d..80e36588c 100644 --- a/pytype/pytd/abc_hierarchy.py +++ b/pytype/pytd/abc_hierarchy.py @@ -1,10 +1,12 @@ """Hierarchy of abstract base classes, from _collections_abc.py.""" +import collections + from pytype import utils # class -> list of superclasses -SUPERCLASSES = { +SUPERCLASSES: dict[str, list[str]] = { # "mixins" (don't derive from object): "Hashable": [], "Iterable": [], @@ -74,7 +76,7 @@ } -def GetSuperClasses(): +def GetSuperClasses() -> dict[str, list[str]]: """Get a Python type hierarchy mapping. This generates a dictionary that can be used to look up the bases of @@ -88,7 +90,7 @@ def GetSuperClasses(): return SUPERCLASSES.copy() -def GetSubClasses(): +def GetSubClasses() -> collections.defaultdict: """Get a reverse Python type hierarchy mapping. This generates a dictionary that can be used to look up the (known) diff --git a/pytype/pytd/base_visitor.py b/pytype/pytd/base_visitor.py index 068b0b2f6..250aa9fb3 100644 --- a/pytype/pytd/base_visitor.py +++ b/pytype/pytd/base_visitor.py @@ -1,5 +1,6 @@ """Base class for visitors.""" +from collections.abc import Generator import re from typing import Any @@ -9,7 +10,7 @@ # A convenient value for unchecked_node_classnames if a visitor wants to # use unchecked nodes everywhere. -ALL_NODE_NAMES = type( +ALL_NODE_NAMES: Any = type( "contains_everything", (), {"__contains__": lambda *args: True} )() @@ -17,7 +18,7 @@ class _NodeClassInfo: """Representation of a node class in the graph.""" - def __init__(self, cls): + def __init__(self, cls) -> None: self.cls = cls # The class object. 
self.name = cls.__name__ # The set of NodeClassInfo objects that may appear below this particular @@ -25,7 +26,7 @@ def __init__(self, cls): self.outgoing = set() -def _FindNodeClasses(): +def _FindNodeClasses() -> Generator[_NodeClassInfo, Any, None]: """Yields _NodeClassInfo objects for each node found in pytd.""" for name in dir(pytd): value = getattr(pytd, name) @@ -38,11 +39,11 @@ def _FindNodeClasses(): yield _NodeClassInfo(value) -_IGNORED_TYPES = frozenset([str, bool, int, type(None), Any]) -_ancestor_map = None # Memoized ancestors map. +_IGNORED_TYPES: frozenset = frozenset([str, bool, int, type(None), Any]) +_ancestor_map: Any = None # Memoized ancestors map. -def _GetChildTypes(node_classes, cls: Any): +def _GetChildTypes(node_classes, cls: Any) -> set: """Get all the types that can be in a node's subtree.""" types = set() @@ -132,11 +133,13 @@ class Visitor: old_node: Any visits_all_node_types = False - unchecked_node_names = set() + unchecked_node_names: set = set() - _visitor_functions_cache = {} + _visitor_functions_cache: dict[ + Any, tuple[dict[str, Any], dict[str, Any], dict[str, Any], Any] + ] = {} - def __init__(self): + def __init__(self) -> None: cls = self.__class__ # The set of method names for each visitor implementation is assumed to @@ -217,5 +220,5 @@ def Visit(self, node, *args, **kwargs): self, node, *args, **kwargs ) - def Leave(self, node, *args, **kwargs): + def Leave(self, node, *args, **kwargs) -> None: self.leave_functions[node.__class__.__name__](self, node, *args, **kwargs) diff --git a/pytype/pytd/booleq.py b/pytype/pytd/booleq.py index d9173a252..171ab938a 100644 --- a/pytype/pytd/booleq.py +++ b/pytype/pytd/booleq.py @@ -1,11 +1,16 @@ """Data structures and algorithms for boolean equations.""" import collections +from collections.abc import Callable, Generator import itertools +from typing import Any, TypeVar from pytype.pytd import pytd_utils -chain = itertools.chain.from_iterable +_TFalseValue = TypeVar("_TFalseValue", 
bound="FalseValue") +_TTrueValue = TypeVar("_TTrueValue", bound="TrueValue") + +chain: Callable = itertools.chain.from_iterable class BooleanTerm: @@ -56,43 +61,43 @@ def extract_equalities(self): class TrueValue(BooleanTerm): """Class for representing "TRUE".""" - def simplify(self, assignments): + def simplify(self: _TTrueValue, assignments) -> _TTrueValue: return self - def __repr__(self): + def __repr__(self) -> str: return "TRUE" - def __str__(self): + def __str__(self) -> str: return "TRUE" - def extract_pivots(self, assignments): + def extract_pivots(self, assignments) -> dict: return {} - def extract_equalities(self): + def extract_equalities(self) -> tuple[()]: return () class FalseValue(BooleanTerm): """Class for representing "FALSE".""" - def simplify(self, assignments): + def simplify(self: _TFalseValue, assignments) -> _TFalseValue: return self - def __repr__(self): + def __repr__(self) -> str: return "FALSE" - def __str__(self): + def __str__(self) -> str: return "FALSE" - def extract_pivots(self, assignments): + def extract_pivots(self, assignments) -> dict[None, None]: return {} - def extract_equalities(self): + def extract_equalities(self) -> tuple[()]: return () -TRUE = TrueValue() -FALSE = FalseValue() +TRUE: TrueValue = TrueValue() +FALSE: FalseValue = FalseValue() def simplify_exprs(exprs, result_type, stop_term, skip_term): @@ -144,7 +149,7 @@ class _Eq(BooleanTerm): __slots__ = ("left", "right") - def __init__(self, left, right): + def __init__(self, left, right) -> None: """Initialize an equality. 
Args: @@ -154,13 +159,13 @@ def __init__(self, left, right): self.left = left self.right = right - def __repr__(self): + def __repr__(self) -> str: return f"Eq({self.left!r}, {self.right!r})" - def __str__(self): + def __str__(self) -> str: return f"{self.left} == {self.right}" - def __hash__(self): + def __hash__(self) -> int: return hash((self.left, self.right)) def __eq__(self, other): @@ -170,7 +175,7 @@ def __eq__(self, other): and self.right == other.right ) - def __ne__(self, other): + def __ne__(self, other) -> bool: return not self == other def simplify(self, assignments): @@ -193,7 +198,7 @@ def simplify(self, assignments): else: return self if self.right in assignments[self.left] else FALSE - def extract_pivots(self, assignments): + def extract_pivots(self, assignments) -> dict: """Extract the pivots. See BooleanTerm.extract_pivots().""" if self.left in assignments and self.right in assignments: intersection = assignments[self.left] & assignments[self.right] @@ -207,11 +212,11 @@ def extract_pivots(self, assignments): self.right: frozenset((self.left,)), } - def extract_equalities(self): + def extract_equalities(self) -> tuple[tuple[Any, Any]]: return ((self.left, self.right),) -def _expr_set_hash(expr_set): +def _expr_set_hash(expr_set) -> int: # We sort the hash of individual expressions so that two equal sets # have the same hash value. return hash(tuple(sorted(hash(e) for e in expr_set))) @@ -225,7 +230,7 @@ class _And(BooleanTerm): __slots__ = ("exprs",) - def __init__(self, exprs): + def __init__(self, exprs) -> None: """Initialize a conjunction. 
Args: @@ -236,13 +241,13 @@ def __init__(self, exprs): def __eq__(self, other): return self.__class__ == other.__class__ and self.exprs == other.exprs - def __ne__(self, other): + def __ne__(self, other) -> bool: return not self == other - def __repr__(self): + def __repr__(self) -> str: return f"And({list(self.exprs)!r})" - def __str__(self): + def __str__(self) -> str: return "(" + " & ".join(str(t) for t in self.exprs) + ")" def __hash__(self): @@ -265,7 +270,7 @@ def extract_pivots(self, assignments): pivots[name] = values return {var: values for var, values in pivots.items() if values} - def extract_equalities(self): + def extract_equalities(self) -> tuple: return tuple(chain(expr.extract_equalities() for expr in self.exprs)) @@ -277,7 +282,7 @@ class _Or(BooleanTerm): __slots__ = ("exprs",) - def __init__(self, exprs): + def __init__(self, exprs) -> None: """Initialize a disjunction. Args: @@ -288,13 +293,13 @@ def __init__(self, exprs): def __eq__(self, other): # for unit tests return self.__class__ == other.__class__ and self.exprs == other.exprs - def __ne__(self, other): + def __ne__(self, other) -> bool: return not self == other - def __repr__(self): + def __repr__(self) -> str: return f"Or({list(self.exprs)!r})" - def __str__(self): + def __str__(self) -> str: return "(" + " | ".join(str(t) for t in self.exprs) + ")" def __hash__(self): @@ -305,7 +310,7 @@ def simplify(self, assignments): (e.simplify(assignments) for e in self.exprs), _Or, TRUE, FALSE ) - def extract_pivots(self, assignments): + def extract_pivots(self, assignments) -> dict: """Extract the pivots. See BooleanTerm.extract_pivots().""" pivots = {} # dict of frozenset for expr in self.exprs: @@ -317,7 +322,7 @@ def extract_pivots(self, assignments): pivots[name] = values return pivots - def extract_equalities(self): + def extract_equalities(self) -> tuple: return tuple(chain(expr.extract_equalities() for expr in self.exprs)) @@ -401,13 +406,13 @@ class Solver: ANY_VALUE = "?" 
- def __init__(self): + def __init__(self) -> None: self.variables = set() self.implications = collections.defaultdict(dict) self.ground_truth = TRUE self.assignments = None - def __str__(self): + def __str__(self) -> str: lines = [] count_false, count_true = 0, 0 if self.ground_truth is not TRUE: @@ -426,7 +431,7 @@ def __str__(self): count_true, ) - def __repr__(self): + def __repr__(self) -> str: lines = [] for var in self.variables: lines.append(f"solver.register_variable({var!r})") @@ -440,7 +445,7 @@ def register_variable(self, variable): """Register a variable. Call before calling solve().""" self.variables.add(variable) - def always_true(self, formula): + def always_true(self, formula) -> None: """Register a ground truth. Call before calling solve().""" assert formula is not FALSE self.ground_truth = And([self.ground_truth, formula]) @@ -457,7 +462,7 @@ def implies(self, e: BooleanTerm, implication: BooleanTerm) -> None: # (ASCII value 126), e.left should always be the variable. self.implications[e.left][e.right] = implication - def _iter_implications(self): + def _iter_implications(self) -> Generator[tuple[Any, Any, Any], Any, None]: for var, value_to_implication in self.implications.items(): for value, implication in value_to_implication.items(): yield (var, value, implication) @@ -469,7 +474,7 @@ def _get_nonfalse_values(self, var): if implication is not FALSE } - def _get_first_approximation(self): + def _get_first_approximation(self) -> dict: """Get all (variable, value) combinations to consider. This gets the (variable, value) combinations that the solver needs to @@ -514,7 +519,7 @@ def _get_first_approximation(self): return value_assignments - def _complete(self): + def _complete(self) -> None: """Insert missing implications. 
Insert all implications needed to have one implication for every diff --git a/pytype/pytd/codegen/decorate.py b/pytype/pytd/codegen/decorate.py index 2887c37b8..1762fec54 100644 --- a/pytype/pytd/codegen/decorate.py +++ b/pytype/pytd/codegen/decorate.py @@ -1,6 +1,6 @@ """Apply decorators to classes and functions.""" -from collections.abc import Iterable +from collections.abc import Iterable, Callable from pytype.pytd import base_visitor from pytype.pytd import pytd @@ -11,7 +11,7 @@ class ValidateDecoratedClassVisitor(base_visitor.Visitor): """Apply class decorators.""" - def EnterClass(self, cls): + def EnterClass(self, cls) -> None: validate_class(cls) @@ -179,7 +179,7 @@ def validate_class(cls: pytd.Class) -> None: # change to hide that implementation detail. We also add an implicit # "auto_attribs=True" to @attr.s decorators in pyi files. -_DECORATORS = { +_DECORATORS: dict[str, Callable[[pytd.Class], pytd.Class]] = { "dataclasses.dataclass": decorate_dataclass, "attr.s": decorate_attrs, "attr.attrs": decorate_attrs, @@ -188,7 +188,7 @@ def validate_class(cls: pytd.Class) -> None: } -_VALIDATORS = { +_VALIDATORS: dict[str, Callable[[pytd.Class], None]] = { "dataclasses.dataclass": check_class, "attr.s": check_class, "attr.attrs": check_class, diff --git a/pytype/pytd/codegen/function.py b/pytype/pytd/codegen/function.py index 1031578d7..7a2ca68f6 100644 --- a/pytype/pytd/codegen/function.py +++ b/pytype/pytd/codegen/function.py @@ -2,15 +2,20 @@ from collections.abc import Iterable import dataclasses -from typing import Any +from typing import Any, TypeVar from pytype.pytd import pytd +_TNameAndSig = TypeVar("_TNameAndSig", bound="NameAndSig") +_T_DecoratedFunction = TypeVar( + "_T_DecoratedFunction", bound="_DecoratedFunction" +) + class OverloadedDecoratorError(Exception): """Inconsistent decorators on an overloaded function.""" - def __init__(self, name, typ): + def __init__(self, name, typ) -> None: msg = f"Overloaded signatures for '{name}' disagree on 
{typ} decorators" super().__init__(msg) @@ -18,7 +23,7 @@ def __init__(self, name, typ): class PropertyDecoratorError(Exception): """Inconsistent property decorators on an overloaded function.""" - def __init__(self, name, explanation): + def __init__(self, name, explanation) -> None: msg = f"Invalid property decorators for '{name}': {explanation}" super().__init__(msg) @@ -175,7 +180,7 @@ class _Properties: setter: pytd.Signature | None = None deleter: pytd.Signature | None = None - def set(self, prop, sig, name): + def set(self, prop, sig, name) -> None: assert hasattr(self, prop), prop if getattr(self, prop): msg = (f"need at most one each of @property, @{name}.setter, and " @@ -184,7 +189,7 @@ def set(self, prop, sig, name): setattr(self, prop, sig) -def _has_decorator(fn, dec): +def _has_decorator(fn, dec) -> bool: return any(d.type.name == dec for d in fn.decorators) @@ -202,16 +207,19 @@ class _DecoratedFunction: prop_names: dict[str, _Property] = dataclasses.field(init=False) @classmethod - def make(cls, fn: NameAndSig): + def make( + cls: type[_T_DecoratedFunction], fn: NameAndSig + ) -> _T_DecoratedFunction: return cls( name=fn.name, sigs=[fn.signature], is_abstract=fn.is_abstract, is_coroutine=fn.is_coroutine, is_final=fn.is_final, - decorators=fn.decorators) + decorators=fn.decorators, + ) - def __post_init__(self): + def __post_init__(self) -> None: self.prop_names = _property_decorators(self.name) prop_decorators = [d for d in self.decorators if d.name in self.prop_names] if prop_decorators: @@ -220,7 +228,7 @@ def __post_init__(self): else: self.properties = None - def add_property(self, decorators, sig): + def add_property(self, decorators, sig) -> None: """Add a property overload.""" assert decorators if len(decorators) > 1: diff --git a/pytype/pytd/codegen/namedtuple.py b/pytype/pytd/codegen/namedtuple.py index 89bcd6f75..93e88568e 100644 --- a/pytype/pytd/codegen/namedtuple.py +++ b/pytype/pytd/codegen/namedtuple.py @@ -11,7 +11,7 @@ class 
NamedTuple: # This is called from the pyi parser, to convert a namedtuple constructed by a # functional constructor into a NamedTuple subclass. - def __init__(self, base_name, fields, generated_classes): + def __init__(self, base_name, fields, generated_classes) -> None: # Handle previously defined NamedTuples with the same name index = len(generated_classes[base_name]) self.name = escape.pack_namedtuple_base_class(base_name, index) diff --git a/pytype/pytd/codegen/pytdgen.py b/pytype/pytd/codegen/pytdgen.py index 970449877..56081a531 100644 --- a/pytype/pytd/codegen/pytdgen.py +++ b/pytype/pytd/codegen/pytdgen.py @@ -4,7 +4,7 @@ from pytype.pytd import pytd -_STRING_TYPES = ("str", "bytes", "unicode") +_STRING_TYPES: tuple[str, str, str] = ("str", "bytes", "unicode") # Type aliases diff --git a/pytype/pytd/mro.py b/pytype/pytd/mro.py index e9ad3147a..a38d9d979 100644 --- a/pytype/pytd/mro.py +++ b/pytype/pytd/mro.py @@ -3,7 +3,7 @@ from pytype.pytd import pytd -def MergeSequences(seqs): +def MergeSequences(seqs) -> list: """Merge a sequence of sequences into a single sequence. 
This code is copied from https://www.python.org/download/releases/2.3/mro/ @@ -45,7 +45,7 @@ def MergeSequences(seqs): res.append(cand) -def Dedup(seq): +def Dedup(seq) -> list: """Return a sequence in the same order, but with duplicates removed.""" seen = set() result = [] @@ -58,7 +58,7 @@ def Dedup(seq): class MROError(Exception): # pylint: disable=g-bad-exception-name - def __init__(self, seqs): + def __init__(self, seqs) -> None: super().__init__() self.mro_seqs = seqs @@ -121,7 +121,7 @@ def _ComputeMRO(t, mros, lookup_ast): return [t] -def GetBasesInMRO(cls, lookup_ast=None): +def GetBasesInMRO(cls, lookup_ast=None) -> tuple: """Get the given class's bases in Python's method resolution order.""" mros = {} base_mros = [] diff --git a/pytype/pytd/optimize.py b/pytype/pytd/optimize.py index e8f01cf02..f7ba0ff8b 100644 --- a/pytype/pytd/optimize.py +++ b/pytype/pytd/optimize.py @@ -8,6 +8,7 @@ import collections import logging +from typing import Any, TypeVar from pytype import utils from pytype.pytd import abc_hierarchy @@ -17,20 +18,22 @@ from pytype.pytd import pytd_visitors from pytype.pytd import visitors -log = logging.getLogger(__name__) +_T0 = TypeVar("_T0") + +log: logging.Logger = logging.getLogger(__name__) class RenameUnknowns(visitors.Visitor): """Give unknowns that map to the same set of concrete types the same name.""" - def __init__(self, mapping): + def __init__(self, mapping) -> None: super().__init__() self.name_to_cls = {name: hash(cls) for name, cls in mapping.items()} self.cls_to_canonical_name = { cls: name for name, cls in self.name_to_cls.items() } - def VisitClassType(self, node): + def VisitClassType(self, node: _T0) -> pytd.ClassType | _T0: if escape.is_unknown(node.name): return pytd.ClassType( self.cls_to_canonical_name[self.name_to_cls[node.name]], None @@ -50,14 +53,14 @@ class Foo(Generic[T]): def f(self: Foo): ... 
""" - def __init__(self): + def __init__(self) -> None: super().__init__() self.class_stack = [] - def EnterClass(self, node): + def EnterClass(self, node) -> None: self.class_stack.append(node.name) - def LeaveClass(self, node): + def LeaveClass(self, node) -> None: self.class_stack.pop() def VisitFunction(self, node): @@ -125,11 +128,11 @@ class _ReturnsAndExceptions: exceptions: Exceptions seen so far. """ - def __init__(self): + def __init__(self) -> None: self.return_types = [] self.exceptions = [] - def Update(self, signature): + def Update(self, signature) -> None: """Add the return types / exceptions of a signature to this instance.""" if signature.return_type not in self.return_types: @@ -156,7 +159,7 @@ def f(x: int) -> Union[float, int]: raise OverflowError() """ - def _GroupByArguments(self, signatures): + def _GroupByArguments(self, signatures) -> dict[Any, _ReturnsAndExceptions]: """Groups signatures by arguments. Arguments: @@ -214,7 +217,9 @@ class CombineContainers(visitors.Visitor): . """ - _CONTAINER_NAMES = { + _CONTAINER_NAMES: dict[ + type[pytd.CallableType | pytd.TupleType], tuple[str, ...] + ] = { pytd.TupleType: ("builtins.tuple", "typing.Tuple"), pytd.CallableType: ("typing.Callable",), } @@ -225,7 +230,7 @@ def _key(self, t): else: return t.base_type - def _should_merge(self, pytd_type, union): + def _should_merge(self, pytd_type, union) -> bool: """Determine whether pytd_type values in the union should be merged. 
If the union contains the homogeneous flavor of pytd_type (e.g., @@ -323,14 +328,14 @@ def VisitUnionType(self, union): class SuperClassHierarchy: """Utility class for optimizations working with superclasses.""" - def __init__(self, superclasses): + def __init__(self, superclasses) -> None: self._superclasses = superclasses self._subclasses = utils.invert_dict(self._superclasses) def GetSuperClasses(self): return self._superclasses - def _CollectSuperclasses(self, type_name, collect): + def _CollectSuperclasses(self, type_name, collect) -> None: """Recursively collect super classes for a type. Arguments: @@ -342,7 +347,7 @@ def _CollectSuperclasses(self, type_name, collect): for superclass in self._superclasses.get(type_name, []): self._CollectSuperclasses(superclass, collect) - def ExpandSuperClasses(self, t): + def ExpandSuperClasses(self, t) -> set[None]: """Generate a list of all (known) superclasses for a type. Arguments: @@ -356,7 +361,7 @@ def ExpandSuperClasses(self, t): self._CollectSuperclasses(t, superclasses) return superclasses - def ExpandSubClasses(self, t): + def ExpandSubClasses(self, t: _T0) -> set[_T0]: """Generate a set of all (known) subclasses for a type. Arguments: @@ -375,11 +380,11 @@ def ExpandSubClasses(self, t): queue.extend(self._subclasses[item]) return seen - def HasSubClassInSet(self, cls, known): + def HasSubClassInSet(self, cls, known) -> bool: """Queries whether a subclass of a type is present in a given set.""" return any(sub in known for sub in self._subclasses[cls]) - def HasSuperClassInSet(self, cls, known): + def HasSuperClassInSet(self, cls, known) -> bool: """Queries whether a superclass of a type is present in a given set.""" return any(sub in known for sub in self._superclasses[cls]) @@ -397,7 +402,7 @@ class SimplifyUnionsWithSuperclasses(visitors.Visitor): A union B = A, if B is a subset of A.) 
""" - def __init__(self, hierarchy): + def __init__(self, hierarchy) -> None: super().__init__() self.hierarchy = hierarchy @@ -422,7 +427,7 @@ def f(x: Union[list, tuple], y: Union[frozenset, set]) -> Union[int, float] def f(x: Sequence, y: Set) -> Real """ - def __init__(self, hierarchy): + def __init__(self, hierarchy) -> None: super().__init__() self.hierarchy = hierarchy @@ -475,7 +480,7 @@ class CollapseLongUnions(visitors.Visitor): more types than this, it is shortened. """ - def __init__(self, max_length: int = 7): + def __init__(self, max_length: int = 7) -> None: super().__init__() self.generic_type = pytd.AnythingType() self.max_length = max_length @@ -494,12 +499,12 @@ def VisitUnionType(self, union): class AdjustGenericType(visitors.Visitor): """Changes the generic type from "object" to "Any".""" - def __init__(self): + def __init__(self) -> None: super().__init__() self.old_generic_type = pytd.ClassType("builtins.object") self.new_generic_type = pytd.AnythingType() - def VisitClassType(self, t): + def VisitClassType(self, t: _T0) -> pytd.AnythingType | _T0: if t == self.old_generic_type: return self.new_generic_type else: @@ -579,7 +584,7 @@ def m(self, ...) . """ - def __init__(self): + def __init__(self) -> None: super().__init__() self._module = None self._total_count = collections.defaultdict(int) @@ -641,7 +646,7 @@ def _CanDelete(self, cls): return False return self._processed_count[cls.name] == self._total_count[cls.name] - def EnterTypeDeclUnit(self, module): + def EnterTypeDeclUnit(self, module) -> None: # Since modules are hierarchical, we enter TypeDeclUnits multiple times- # but we only want to record the top-level one. 
if not self._module: @@ -652,11 +657,11 @@ def VisitTypeDeclUnit(self, unit): classes=tuple(c for c in unit.classes if not self._CanDelete(c)) ) - def VisitClassType(self, t): + def VisitClassType(self, t: _T0) -> _T0: self._total_count[t.name] += 1 return t - def VisitNamedType(self, t): + def VisitNamedType(self, t: _T0) -> _T0: self._total_count[t.name] += 1 return t @@ -733,32 +738,32 @@ def VisitGenericType(self, t): class TypeParameterScope(visitors.Visitor): """Common superclass for optimizations that track type parameters.""" - def __init__(self): + def __init__(self) -> None: super().__init__() self.type_params_stack = [{}] - def EnterClass(self, cls): + def EnterClass(self, cls) -> None: new = self.type_params_stack[-1].copy() new.update({t.type_param: cls for t in cls.template}) self.type_params_stack.append(new) - def EnterSignature(self, sig): + def EnterSignature(self, sig) -> None: new = self.type_params_stack[-1].copy() new.update({t.type_param: sig for t in sig.template}) self.type_params_stack.append(new) - def IsClassTypeParameter(self, type_param): + def IsClassTypeParameter(self, type_param) -> bool: class_or_sig = self.type_params_stack[-1].get(type_param) return isinstance(class_or_sig, pytd.Class) - def IsFunctionTypeParameter(self, type_param): + def IsFunctionTypeParameter(self, type_param) -> bool: class_or_sig = self.type_params_stack[-1].get(type_param) return isinstance(class_or_sig, pytd.Signature) - def LeaveClass(self, _): + def LeaveClass(self, _) -> None: self.type_params_stack.pop() - def LeaveSignature(self, _): + def LeaveSignature(self, _) -> None: self.type_params_stack.pop() @@ -786,36 +791,36 @@ def append(self, V:T') -> NoneType mutations to the outermost level (in this example, T' = Union[T, T2]) """ - def __init__(self): + def __init__(self) -> None: super().__init__() self.type_param_union = None - def _AppendNew(self, l1, l2): + def _AppendNew(self, l1, l2) -> None: """Appends all items to l1 that are not in l2.""" # 
l1 and l2 are small (2-3 elements), so just use two loops. for e2 in l2: if not any(e1 is e2 for e1 in l1): l1.append(e2) - def EnterSignature(self, sig): + def EnterSignature(self, sig) -> None: # Necessary because TypeParameterScope also defines this function super().EnterSignature(sig) assert self.type_param_union is None self.type_param_union = collections.defaultdict(list) - def LeaveSignature(self, node): + def LeaveSignature(self, node) -> None: # Necessary because TypeParameterScope also defines this function super().LeaveSignature(node) self.type_param_union = None - def VisitUnionType(self, u): + def VisitUnionType(self, u: _T0) -> _T0: type_params = [t for t in u.type_list if isinstance(t, pytd.TypeParameter)] for t in type_params: if self.IsFunctionTypeParameter(t): self._AppendNew(self.type_param_union[t.name], type_params) return u - def _AllContaining(self, type_param, seen=None): + def _AllContaining(self, type_param: _T0, seen=None) -> list[_T0]: """Gets all type parameters that are in a union with the passed one.""" seen = seen or set() result = [type_param] @@ -826,7 +831,7 @@ def _AllContaining(self, type_param, seen=None): self._AppendNew(result, self._AllContaining(other, seen) or [other]) return result - def _ReplaceByOuterIfNecessary(self, item, substitutions): + def _ReplaceByOuterIfNecessary(self, item: _T0, substitutions) -> list[_T0]: """Potentially replace a function type param with a class type param. 
Args: diff --git a/pytype/pytd/parse/parser_constants.py b/pytype/pytd/parse/parser_constants.py index 295ce86cf..f20641425 100644 --- a/pytype/pytd/parse/parser_constants.py +++ b/pytype/pytd/parse/parser_constants.py @@ -3,7 +3,7 @@ import re # PyTD keywords -RESERVED = [ +RESERVED: list[str] = [ 'async', 'class', 'def', @@ -23,7 +23,7 @@ 'TypeVar', ] -RESERVED_PYTHON = [ +RESERVED_PYTHON: list[str] = [ # Python keywords that aren't used by PyTD: 'and', 'assert', @@ -52,11 +52,11 @@ # parser.t_NAME's regex allows a few extra characters in the name. # A less-pedantic RE is r'[-~]'. # See visitors._EscapedName and parser.PyLexer.t_NAME -BACKTICK_NAME = re.compile(r'[-]|^~') +BACKTICK_NAME: re.Pattern[str] = re.compile(r'[-]|^~') # Marks external NamedTypes so that they do not get prefixed by the current # module name. EXTERNAL_NAME_PREFIX = '$external$' # Regex for string literals. -STRING_RE = re.compile("^([bu]?)(('[^']*')|(\"[^\"]*\"))$") +STRING_RE: re.Pattern[str] = re.compile("^([bu]?)(('[^']*')|(\"[^\"]*\"))$") diff --git a/pytype/pytd/parse/parser_test_base.py b/pytype/pytd/parse/parser_test_base.py index 7100acebb..beb255803 100644 --- a/pytype/pytd/parse/parser_test_base.py +++ b/pytype/pytd/parse/parser_test_base.py @@ -18,12 +18,12 @@ class ParserTest(test_base.UnitTest): loader: load_pytd.Loader @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: super().setUpClass() cls.loader = load_pytd.Loader( config.Options.create(python_version=cls.python_version)) - def setUp(self): + def setUp(self) -> None: super().setUp() self.options = parser.PyiOptions(python_version=self.python_version) @@ -60,7 +60,7 @@ def ToAST(self, src_or_tree): src_or_tree.Visit(visitors.VerifyVisitor()) return src_or_tree - def AssertSourceEquals(self, src_or_tree_1, src_or_tree_2): + def AssertSourceEquals(self, src_or_tree_1, src_or_tree_2) -> None: # Strip leading "\n"s for convenience ast1 = self.ToAST(src_or_tree_1) ast2 = self.ToAST(src_or_tree_2) diff --git 
a/pytype/pytd/pep484.py b/pytype/pytd/pep484.py index bed33d994..5260fcf02 100644 --- a/pytype/pytd/pep484.py +++ b/pytype/pytd/pep484.py @@ -1,10 +1,13 @@ """PEP484 compatibility code.""" +from typing import TypeVar from pytype.pytd import base_visitor from pytype.pytd import pytd +_T0 = TypeVar("_T0") -ALL_TYPING_NAMES = [ + +ALL_TYPING_NAMES: list[str] = [ "AbstractSet", "AnyStr", "AsyncGenerator", @@ -49,7 +52,7 @@ # Pairs of a type and a more generalized type. -_COMPAT_ITEMS = [ +_COMPAT_ITEMS: list[tuple[str, str]] = [ ("int", "float"), ("int", "complex"), ("float", "complex"), @@ -57,10 +60,9 @@ ("memoryview", "bytes"), ] - # The PEP 484 definition of built-in types. # E.g. "typing.List" is used to represent the "list" type. -BUILTIN_TO_TYPING = { +BUILTIN_TO_TYPING: dict[str, str] = { t.lower(): t for t in [ "List", @@ -72,10 +74,10 @@ ] } -TYPING_TO_BUILTIN = {v: k for k, v in BUILTIN_TO_TYPING.items()} +TYPING_TO_BUILTIN: dict[str, str] = {v: k for k, v in BUILTIN_TO_TYPING.items()} -def get_compat_items(none_matches_bool=False): +def get_compat_items(none_matches_bool=False) -> list[tuple[str, str]]: # pep484 allows None as an alias for NoneType in type annotations. extra = [("NoneType", "bool"), ("None", "bool")] if none_matches_bool else [] return _COMPAT_ITEMS + extra @@ -84,7 +86,7 @@ def get_compat_items(none_matches_bool=False): class ConvertTypingToNative(base_visitor.Visitor): """Visitor for converting PEP 484 types to native representation.""" - def __init__(self, module): + def __init__(self, module) -> None: super().__init__() self.module = module @@ -97,7 +99,7 @@ def _GetModuleAndName(self, t): def _IsTyping(self, module): return module == "typing" or (module is None and self.module == "typing") - def _Convert(self, t): + def _Convert(self, t: _T0) -> pytd.AnythingType | pytd.NamedType | _T0: module, name = self._GetModuleAndName(t) if not module and name == "None": # PEP 484 allows "None" as an abbreviation of "NoneType". 
@@ -120,7 +122,9 @@ def VisitClassType(self, t): def VisitNamedType(self, t): return self._Convert(t) - def VisitGenericType(self, t): + def VisitGenericType( + self, t: _T0 + ) -> pytd.IntersectionType | pytd.UnionType | _T0: module, name = self._GetModuleAndName(t) if self._IsTyping(module): if name == "Intersection": diff --git a/pytype/pytd/printer.py b/pytype/pytd/printer.py index 3130652bc..14815a6f2 100644 --- a/pytype/pytd/printer.py +++ b/pytype/pytd/printer.py @@ -3,6 +3,7 @@ import collections import logging import re +from typing import Any, TypeVar from pytype import utils from pytype.pytd import base_visitor @@ -10,6 +11,11 @@ from pytype.pytd import pytd from pytype.pytd.parse import parser_constants +_NameType: type[str] +_AliasType: type[str] + +_TPrintVisitor = TypeVar("_TPrintVisitor", bound="PrintVisitor") + # Aliases for readability: _NameType = _AliasType = str @@ -17,7 +23,7 @@ class _TypingImports: """Imports from the `typing` module.""" - def __init__(self): + def __init__(self) -> None: # Typing members that are imported via `from typing import ...`. self._members: dict[_AliasType, _NameType] = {} # The number of times that each typing member is used. @@ -28,14 +34,14 @@ def members(self): # Note that when a typing member has multiple aliases, this keeps only one. 
return {name: alias for alias, name in self._members.items()} - def add(self, name: str, alias: str): + def add(self, name: str, alias: str) -> None: self._counts[name] += 1 self._members[alias] = name - def decrement_count(self, name: str): + def decrement_count(self, name: str) -> None: self._counts[name] -= 1 - def to_import_statements(self): + def to_import_statements(self) -> list[str]: targets = [] for alias, name in self._members.items(): if not self._counts[name]: @@ -50,7 +56,7 @@ def to_import_statements(self): class _Imports: """Imports tracker.""" - def __init__(self): + def __init__(self) -> None: self.track_imports = True self._typing = _TypingImports() self._direct_imports: dict[_AliasType, _NameType] = {} @@ -94,7 +100,7 @@ def add(self, full_name: str, alias: str | None = None): self._from_imports.setdefault(module, {})[alias] = name self._reverse_alias_map[full_name] = alias - def decrement_typing_count(self, member: str): + def decrement_typing_count(self, member: str) -> None: self._typing.decrement_count(member) def get_alias(self, name: str): @@ -102,7 +108,7 @@ def get_alias(self, name: str): return self._typing.members.get(utils.strip_prefix(name, "typing.")) return self._reverse_alias_map.get(name) - def to_import_statements(self): + def to_import_statements(self) -> list: """Converts self to import statements.""" imports = self._typing.to_import_statements() for alias, module in self._direct_imports.items(): @@ -127,14 +133,14 @@ class PrintVisitor(base_visitor.Visitor): """Visitor for converting ASTs back to pytd source code.""" visits_all_node_types = True - unchecked_node_names = base_visitor.ALL_NODE_NAMES + unchecked_node_names: Any = base_visitor.ALL_NODE_NAMES INDENT = " " * 4 - _RESERVED = frozenset( + _RESERVED: frozenset[str] = frozenset( parser_constants.RESERVED + parser_constants.RESERVED_PYTHON ) - def __init__(self, multiline_args=False): + def __init__(self, multiline_args=False) -> None: super().__init__() self.class_names 
= [] # can contain nested classes self.in_alias = False @@ -155,7 +161,7 @@ def __init__(self, multiline_args=False): def typing_imports(self): return self._imports.typing_members - def copy(self): + def copy(self: _TPrintVisitor) -> _TPrintVisitor: # Note that copy.deepcopy is too slow to use here. copy = PrintVisitor(self.multiline_args) copy.in_alias = self.in_alias @@ -203,7 +209,7 @@ def _LookupTypingMember(self, name): return f"{prefix_alias}.{name}" raise AssertionError("This should never happen.") - def _FormatTypeParams(self, type_params): + def _FormatTypeParams(self, type_params) -> list[str]: formatted_type_params = [] for t in type_params: if t.full_name == "typing.Self": @@ -250,7 +256,7 @@ def _StripUnitPrefix(self, name): else: return name - def _IsAliasImport(self, node): + def _IsAliasImport(self, node) -> bool: if not self._unit or self.in_constant or self.in_signature: return False elif isinstance(node.type, pytd.Module): @@ -262,7 +268,7 @@ def _IsAliasImport(self, node): and "." in node.type.name ) - def _ProcessDecorators(self, node): + def _ProcessDecorators(self, node) -> list: # Our handling of class and function decorators is a bit hacky (see # output.py); this makes sure that typing classes read in directly from a # pyi file and then reemitted (e.g. 
in assertTypesMatchPytd) have their @@ -274,7 +280,7 @@ def _ProcessDecorators(self, node): self.VisitNamedType(d.type) return utils.unique_list(decorators) - def EnterTypeDeclUnit(self, unit): + def EnterTypeDeclUnit(self, unit) -> None: self._unit = unit definitions = ( unit.classes @@ -298,11 +304,11 @@ def EnterTypeDeclUnit(self, unit): x.name for x in unit.type_params if isinstance(x, pytd.ParamSpec) } - def LeaveTypeDeclUnit(self, _): + def LeaveTypeDeclUnit(self, _) -> None: self._unit = None self._local_names = set() - def VisitTypeDeclUnit(self, node): + def VisitTypeDeclUnit(self, node) -> str: """Convert the AST for an entire module back to a string.""" for t in self.old_node.type_params: if isinstance(t, pytd.ParamSpec): @@ -335,13 +341,13 @@ def VisitTypeDeclUnit(self, node): ) return "\n\n".join(sections_as_string) - def EnterConstant(self, node): + def EnterConstant(self, node) -> None: self.in_constant = True - def LeaveConstant(self, node): + def LeaveConstant(self, node) -> None: self.in_constant = False - def _DropTypingConstant(self, node): + def _DropTypingConstant(self, node) -> bool | None: # Hack to account for a corner case in late annotation handling. # If we have a top-level constant of the exact form # Foo: Type[typing.Foo] @@ -383,11 +389,11 @@ def VisitConstant(self, node): suffix = " = ..." 
if node.value else "" return f"{node.name}: {node.type}{suffix}" - def EnterAlias(self, node): + def EnterAlias(self, node) -> None: if self.in_function or self._IsAliasImport(node): self._imports.track_imports = False - def LeaveAlias(self, _): + def LeaveAlias(self, _) -> None: self._imports.track_imports = True def VisitAlias(self, node): @@ -400,7 +406,7 @@ def VisitAlias(self, node): return node.type return f"{node.name} = {node.type}" - def EnterClass(self, node): + def EnterClass(self, node) -> set[str]: """Entering a class - record class name for children's use.""" n = node.name if node.template: @@ -414,11 +420,11 @@ def EnterClass(self, node): # of 'Any' when generating a decorator for an InterpreterClass.) return {"decorators"} - def LeaveClass(self, unused_node): + def LeaveClass(self, unused_node) -> None: self._class_members.clear() self.class_names.pop() - def VisitClass(self, node): + def VisitClass(self, node) -> str: """Visit a class, producing a multi-line, properly indented string.""" bases = node.bases if bases == ("TypedDict",): @@ -471,13 +477,13 @@ def VisitClass(self, node): lines = decorators + header + slots + classes + constants + methods return "\n".join(lines) + "\n" - def EnterFunction(self, node): + def EnterFunction(self, node) -> None: self.in_function = True - def LeaveFunction(self, node): + def LeaveFunction(self, node) -> None: self.in_function = False - def VisitFunction(self, node): + def VisitFunction(self, node) -> str: """Visit function, producing multi-line string (one for each signature).""" function_name = node.name if self.old_node.decorators: @@ -524,13 +530,13 @@ def _FormatContainerContents(self, node: pytd.Parameter) -> str: else: return self.Print(node.Replace(type=pytd.AnythingType(), optional=False)) - def EnterSignature(self, node): + def EnterSignature(self, node) -> None: self.in_signature = True - def LeaveSignature(self, node): + def LeaveSignature(self, node) -> None: self.in_signature = False - def 
VisitSignature(self, node): + def VisitSignature(self, node) -> str: """Visit a signature, producing a string.""" if node.return_type == "nothing": return_type = self._FromTyping("Never") # a prettier alias for nothing @@ -594,15 +600,15 @@ def VisitSignature(self, node): params = ", ".join(params) return f"({params}){ret}:{''.join(body)}" - def EnterParameter(self, unused_node): + def EnterParameter(self, unused_node) -> None: assert not self.in_parameter self.in_parameter = True - def LeaveParameter(self, unused_node): + def LeaveParameter(self, unused_node) -> None: assert self.in_parameter self.in_parameter = False - def _DecrementParameterImports(self, name): + def _DecrementParameterImports(self, name) -> None: if "[" not in name: return param = name.split("[", 1)[-1] @@ -658,7 +664,7 @@ def VisitTemplateItem(self, node): """Convert a template to a string.""" return node.type_param - def _UseExistingModuleAlias(self, name): + def _UseExistingModuleAlias(self, name) -> str | None: prefix, suffix = name.rsplit(".", 1) while prefix: prefix_alias = self._imports.get_alias(prefix) @@ -668,7 +674,7 @@ def _UseExistingModuleAlias(self, name): suffix = f"{remainder}.{suffix}" return None - def _GuessModule(self, maybe_module): + def _GuessModule(self, maybe_module) -> tuple[Any, Any]: """Guess which part of the given name is the module prefix.""" if "." 
not in maybe_module: return maybe_module, "" @@ -735,7 +741,7 @@ def VisitAnythingType(self, unused_node): """Convert an anything type to a string.""" return self._FromTyping("Any") - def VisitNothingType(self, unused_node): + def VisitNothingType(self, unused_node) -> str: """Convert the nothing type to a string.""" return "nothing" @@ -745,13 +751,13 @@ def VisitTypeParameter(self, node): def VisitParamSpec(self, node): return node.name - def VisitParamSpecArgs(self, node): + def VisitParamSpecArgs(self, node) -> str: return f"{node.name}.args" - def VisitParamSpecKwargs(self, node): + def VisitParamSpecKwargs(self, node) -> str: return f"{node.name}.kwargs" - def VisitModule(self, node): + def VisitModule(self, node) -> str: return "module" def VisitGenericType(self, node): @@ -771,7 +777,7 @@ def VisitGenericType(self, node): parameters = ("...",) + parameters[1:] return node.base_type + "[" + ", ".join(str(p) for p in parameters) + "]" - def VisitCallableType(self, node): + def VisitCallableType(self, node) -> str: typ = node.base_type if len(node.args) == 1 and node.args[0] in self._paramspec_names: return f"{typ}[{node.args[0]}, {node.ret}]" @@ -782,7 +788,7 @@ def VisitCallableType(self, node): args = ", ".join(node.args) return f"{typ}[[{args}], {node.ret}]" - def VisitConcatenate(self, node): + def VisitConcatenate(self, node) -> str: base = self._FromTyping("Concatenate") parameters = ", ".join(node.parameters) return f"{base}[{parameters}]" @@ -800,7 +806,7 @@ def VisitIntersectionType(self, node): type_list = self._FormSetTypeList(node) return self._BuildIntersection(type_list) - def _FormSetTypeList(self, node): + def _FormSetTypeList(self, node) -> dict[Any, None]: """Form list of types within a set type.""" type_list = dict.fromkeys(node.type_list) if self.in_parameter: @@ -860,19 +866,19 @@ def _BuildIntersection(self, type_list): else: return " and ".join(type_list) - def EnterLiteral(self, _): + def EnterLiteral(self, _) -> None: assert not 
self.in_literal self.in_literal = True - def LeaveLiteral(self, _): + def LeaveLiteral(self, _) -> None: assert self.in_literal self.in_literal = False - def VisitLiteral(self, node): + def VisitLiteral(self, node) -> str: base = self._FromTyping("Literal") return f"{base}[{node.value}]" - def VisitAnnotated(self, node): + def VisitAnnotated(self, node) -> str: base = self._FromTyping("Annotated") annotations = ", ".join(node.annotations) return f"{base}[{node.base_type}, {annotations}]" diff --git a/pytype/pytd/pytd.py b/pytype/pytd/pytd.py index 23be32cfd..2bdba4064 100644 --- a/pytype/pytd/pytd.py +++ b/pytype/pytd/pytd.py @@ -19,12 +19,14 @@ from collections.abc import Generator import enum import itertools -from typing import Any, Union +from typing import Any, TypeVar, Union from pytype.pytd.parse import node +_TMethodFlag = TypeVar('_TMethodFlag', bound='MethodFlag') + # Alias node.Node for convenience. -Node = node.Node +Node: type[Node] = node.Node class Type(Node): @@ -53,7 +55,7 @@ class TypeDeclUnit(Node, eq=False): # in equality or hash operations. _name2item: dict[str, Any] = {} - def _InitCache(self): + def _InitCache(self) -> None: # TODO(b/159053187): Put constants, functions, classes and aliases into a # combined dict. for x in (self.constants, self.functions, self.classes, self.aliases): @@ -86,7 +88,7 @@ def Get(self, name): self._InitCache() return self._name2item.get(name) - def __contains__(self, name): + def __contains__(self, name) -> bool: return bool(self.Get(name)) def IterChildren(self) -> Generator[tuple[str, Any | None], None, None]: @@ -102,7 +104,7 @@ def Replace(self, **kwargs): # The hash/eq/ne values are used for caching and speed things up quite a bit. - def __hash__(self): + def __hash__(self) -> int: return id(self) @@ -165,7 +167,7 @@ class Class(Node): # in equality or hash operations. 
_name2item: dict[str, Any] = {} - def _InitCache(self): + def _InitCache(self) -> None: # TODO(b/159053187): Put constants, functions, classes and aliases into a # combined dict. for x in (self.methods, self.constants, self.classes): @@ -197,7 +199,7 @@ def Get(self, name): self._InitCache() return self._name2item.get(name) - def __contains__(self, name): + def __contains__(self, name) -> bool: return bool(self.Get(name)) def __hash__(self): @@ -228,10 +230,10 @@ def metaclass(self): class MethodKind(enum.Enum): - METHOD = 'method' - STATICMETHOD = 'staticmethod' - CLASSMETHOD = 'classmethod' - PROPERTY = 'property' + METHOD: Literal['method'] = 'method' + STATICMETHOD: Literal['staticmethod'] = 'staticmethod' + CLASSMETHOD: Literal['classmethod'] = 'classmethod' + PROPERTY: Literal['property'] = 'property' class MethodFlag(enum.Flag): @@ -241,7 +243,7 @@ class MethodFlag(enum.Flag): FINAL = enum.auto() @classmethod - def abstract_flag(cls, is_abstract): # pylint: disable=invalid-name + def abstract_flag(cls: type[_TMethodFlag], is_abstract) -> _TMethodFlag: # pylint: disable=invalid-name # Useful when creating functions directly (other flags aren't needed there). 
return cls.ABSTRACT if is_abstract else cls.NONE @@ -309,9 +311,9 @@ def has_optional(self): class ParameterKind(enum.Enum): - REGULAR = 'regular' - POSONLY = 'posonly' - KWONLY = 'kwonly' + REGULAR: Literal['regular'] = 'regular' + POSONLY: Literal['posonly'] = 'posonly' + KWONLY: Literal['kwonly'] = 'kwonly' class Parameter(Node): @@ -384,7 +386,7 @@ def upper_value(self): class ParamSpec(TypeParameter): """ParamSpec is a specific case of TypeParameter.""" - def Get(self, attr): + def Get(self, attr) -> ParamSpecArgs | ParamSpecKwargs | None: if attr == 'args': return ParamSpecArgs(self.name) elif attr == 'kwargs': @@ -455,7 +457,7 @@ class NamedType(Type): name: str - def __str__(self): + def __str__(self) -> str: return self.name @@ -481,19 +483,19 @@ def IterChildren(self) -> Generator[tuple[str, Any | None], None, None]: # this, we claim that `name` is the only child. yield 'name', self.name - def __eq__(self, other): + def __eq__(self, other) -> bool: return self.__class__ == other.__class__ and self.name == other.name - def __ne__(self, other): + def __ne__(self, other) -> bool: return not self == other - def __hash__(self): + def __hash__(self) -> int: return hash((self.__class__.__name__, self.name)) - def __str__(self): + def __str__(self) -> str: return str(self.cls.name) if self.cls else self.name - def __repr__(self): + def __repr__(self) -> str: return '{type}{cls}({name})'.format( type=type(self).__name__, name=self.name, @@ -507,14 +509,14 @@ class LateType(Type): name: str recursive: bool = False - def __str__(self): + def __str__(self) -> str: return self.name class AnythingType(Type): """A type we know nothing about yet (? in pytd).""" - def __bool__(self): + def __bool__(self) -> bool: return True @@ -525,7 +527,7 @@ class NothingType(Type): For representing empty lists, and functions that never return. 
""" - def __bool__(self): + def __bool__(self) -> bool: return True @@ -551,10 +553,10 @@ class _SetOfTypes(Type, frozen=False, eq=False): # parentheses gives the same result. type_list: tuple[TypeU, ...] = () - def __post_init__(self): + def __post_init__(self) -> None: self.type_list = _FlattenTypes(self.type_list) - def __eq__(self, other): + def __eq__(self, other) -> bool: if self is other: return True if isinstance(other, type(self)): @@ -562,10 +564,10 @@ def __eq__(self, other): return frozenset(self.type_list) == frozenset(other.type_list) return NotImplemented - def __ne__(self, other): + def __ne__(self, other) -> bool: return not self == other - def __hash__(self): + def __hash__(self) -> int: return hash(self.type_list) @@ -651,7 +653,10 @@ class Annotated(Type): # Types that can be a base type of GenericType: -GENERIC_BASE_TYPE = (NamedType, ClassType) +GENERIC_BASE_TYPE: tuple[type[NamedType], type[ClassType]] = ( + NamedType, + ClassType, +) # msgspec will not deserialize subclasses. That is, for a class like: @@ -716,7 +721,7 @@ def IsContainer(t: Class) -> bool: # Singleton objects that will be automatically converted to their types. # The unqualified form is there so local name resolution can special-case it. 
-SINGLETON_TYPES = frozenset({'Ellipsis', 'builtins.Ellipsis'}) +SINGLETON_TYPES: frozenset[str] = frozenset({'Ellipsis', 'builtins.Ellipsis'}) def ToType( diff --git a/pytype/pytd/pytd_utils.py b/pytype/pytd/pytd_utils.py index 666a41abf..2500a0aa4 100644 --- a/pytype/pytd/pytd_utils.py +++ b/pytype/pytd/pytd_utils.py @@ -13,16 +13,19 @@ import collections import itertools import re +from typing import Any, TypeVar from pytype import utils from pytype.pytd import printer from pytype.pytd import pytd from pytype.pytd import pytd_visitors +_T0 = TypeVar("_T0") -ANON_PARAM = re.compile(r"_[0-9]+") -_TUPLE_NAMES = ("builtins.tuple", "typing.Tuple") +ANON_PARAM: re.Pattern[str] = re.compile(r"_[0-9]+") + +_TUPLE_NAMES: tuple[str, str] = ("builtins.tuple", "typing.Tuple") def UnpackUnion(t): @@ -39,7 +42,9 @@ def UnpackGeneric(t, basename): return None -def MakeClassOrContainerType(base_type, type_arguments, homogeneous): +def MakeClassOrContainerType( + base_type: _T0, type_arguments, homogeneous +) -> pytd.GenericType | _T0: """If we have type params, build a generic type, a normal type otherwise.""" if not type_arguments and (homogeneous or base_type.name not in _TUPLE_NAMES): return base_type @@ -54,7 +59,7 @@ def MakeClassOrContainerType(base_type, type_arguments, homogeneous): return container_type(base_type, tuple(type_arguments)) -def Concat(*args, **kwargs): +def Concat(*args, **kwargs) -> pytd.TypeDeclUnit: """Concatenate two or more pytd ASTs.""" assert all(isinstance(arg, pytd.TypeDeclUnit) for arg in args) name = kwargs.get("name") @@ -157,7 +162,7 @@ def CanonicalOrdering(n): return n.Visit(pytd_visitors.CanonicalOrderingVisitor()) -def GetAllSubClasses(ast): +def GetAllSubClasses(ast) -> collections.defaultdict: """Compute a class->subclasses mapping. 
Args: @@ -178,7 +183,7 @@ def Print(ast, multiline_args=False): return ast.Visit(printer.PrintVisitor(multiline_args)) -def MakeTypeAnnotation(ast, multiline_args=False): +def MakeTypeAnnotation(ast, multiline_args=False) -> tuple[Any, Any]: """Returns a type annotation and any added typing imports.""" vis = printer.PrintVisitor(multiline_args) annotation = ast.Visit(vis) @@ -192,7 +197,7 @@ def CreateModule(name="", **kwargs): return module.Replace(**kwargs) -def WrapTypeDeclUnit(name, items): +def WrapTypeDeclUnit(name, items) -> pytd.TypeDeclUnit: """Given a list (classes, functions, etc.), wrap a pytd around them. Args: @@ -266,7 +271,7 @@ def WrapTypeDeclUnit(name, items): ) -def _check_intersection(items1, items2, name1, name2): +def _check_intersection(items1, items2, name1, name2) -> None: """Check for duplicate identifiers.""" items = set(items1) & set(items2) if items: @@ -292,18 +297,18 @@ def _check_intersection(items1, items2, name1, name2): class TypeBuilder: """Utility class for building union types.""" - def __init__(self): + def __init__(self) -> None: self.union = pytd.NothingType() self.tags = set() - def add_type(self, other): + def add_type(self, other) -> None: """Add a new pytd type to the types represented by this TypeBuilder.""" if isinstance(other, pytd.Annotated): self.tags.update(other.annotations) other = other.base_type self.union = JoinTypes([self.union, other]) - def wrap(self, base): + def wrap(self, base) -> None: """Wrap the type in a generic type.""" self.union = pytd.GenericType( base_type=pytd.NamedType(base), parameters=(self.union,) @@ -316,14 +321,14 @@ def build(self): else: return self.union - def __bool__(self): + def __bool__(self) -> bool: return not isinstance(self.union, pytd.NothingType) # For running under Python 2 __nonzero__ = __bool__ -def NamedOrClassType(name, cls): +def NamedOrClassType(name, cls) -> pytd.ClassType | pytd.NamedType: """Create Classtype / NamedType.""" if cls is None: return 
pytd.NamedType(name) @@ -331,7 +336,7 @@ def NamedOrClassType(name, cls): return pytd.ClassType(name, cls) -def NamedTypeWithModule(name, module=None): +def NamedTypeWithModule(name, module=None) -> pytd.NamedType: """Create NamedType, dotted if we have a module.""" if module is None: return pytd.NamedType(name) @@ -342,10 +347,10 @@ def NamedTypeWithModule(name, module=None): class OrderedSet(dict): """A simple ordered set.""" - def __init__(self, iterable=None): + def __init__(self, iterable=None) -> None: super().__init__((item, None) for item in (iterable or [])) - def add(self, item): + def add(self, item) -> None: self[item] = None @@ -361,13 +366,13 @@ def ASTeq(ast1: pytd.TypeDeclUnit, ast2: pytd.TypeDeclUnit): ) -def GetTypeParameters(node): +def GetTypeParameters(node) -> list: collector = pytd_visitors.CollectTypeParameters() node.Visit(collector) return collector.params -def DummyMethod(name, *params): +def DummyMethod(name, *params) -> pytd.Function: """Create a simple method using only "Any"s as types. Arguments: @@ -403,7 +408,7 @@ def make_param(param): ) -def MergeBaseClass(cls, base): +def MergeBaseClass(cls, base) -> pytd.Class: """Merge a base class into a subclass. Arguments: diff --git a/pytype/pytd/pytd_visitors.py b/pytype/pytd/pytd_visitors.py index e0956a344..ca06f92fd 100644 --- a/pytype/pytd/pytd_visitors.py +++ b/pytype/pytd/pytd_visitors.py @@ -5,9 +5,12 @@ like to use, feel free to propose moving it here. """ +from typing import Any, TypeVar from pytype.pytd import base_visitor from pytype.pytd import pytd +_T0 = TypeVar("_T0") + # TODO(rechen): IsNamedTuple is being used to disable visitors that shouldn't # operate on generated classes. Should we do the same for dataclasses and attrs? @@ -25,7 +28,7 @@ class CanonicalOrderingVisitor(base_visitor.Visitor): as the signature order determines lookup order. 
""" - def VisitTypeDeclUnit(self, node): + def VisitTypeDeclUnit(self, node) -> pytd.TypeDeclUnit: return pytd.TypeDeclUnit( name=node.name, constants=tuple(sorted(node.constants)), @@ -45,7 +48,7 @@ def _PreserveConstantsOrdering(self, node): # The order of a namedtuple's fields should always be preserved. return IsNamedTuple(node) - def VisitClass(self, node): + def VisitClass(self, node) -> pytd.Class: if self._PreserveConstantsOrdering(node): constants = node.constants else: @@ -68,31 +71,31 @@ def VisitSignature(self, node): exceptions=tuple(sorted(node.exceptions)), ) - def VisitUnionType(self, node): + def VisitUnionType(self, node) -> pytd.UnionType: return pytd.UnionType(tuple(sorted(node.type_list))) class ClassTypeToNamedType(base_visitor.Visitor): """Change all ClassType objects to NameType objects.""" - def VisitClassType(self, node): + def VisitClassType(self, node) -> pytd.NamedType: return pytd.NamedType(node.name) class CollectTypeParameters(base_visitor.Visitor): """Visitor that accumulates type parameters in its "params" attribute.""" - def __init__(self): + def __init__(self) -> None: super().__init__() self._seen = set() self.params = [] - def EnterTypeParameter(self, p): + def EnterTypeParameter(self, p) -> None: if p.name not in self._seen: self.params.append(p) self._seen.add(p.name) - def EnterParamSpec(self, p): + def EnterParamSpec(self, p) -> None: self.EnterTypeParameter(p) @@ -103,19 +106,19 @@ class ExtractSuperClasses(base_visitor.Visitor): to lists of pytd.Type. """ - def __init__(self): + def __init__(self) -> None: super().__init__() self._superclasses = {} - def _Key(self, node): + def _Key(self, node: _T0) -> _T0: # This method should be implemented by subclasses. 
return node - def VisitTypeDeclUnit(self, module): + def VisitTypeDeclUnit(self, module) -> dict[Any, list]: del module return self._superclasses - def EnterClass(self, cls): + def EnterClass(self, cls) -> None: bases = [] for p in cls.bases: base = self._Key(p) @@ -127,7 +130,7 @@ def EnterClass(self, cls): class RenameModuleVisitor(base_visitor.Visitor): """Renames a TypeDeclUnit.""" - def __init__(self, old_module_name, new_module_name): + def __init__(self, old_module_name, new_module_name) -> None: """Constructor. Args: @@ -176,7 +179,7 @@ def _ReplaceModuleName(self, node): else: return node - def VisitClassType(self, node): + def VisitClassType(self, node: _T0) -> pytd.ClassType | _T0: new_name = self._MaybeNewName(node.name) if new_name != node.name: return pytd.ClassType(new_name, node.cls) diff --git a/pytype/pytd/serialize_ast.py b/pytype/pytd/serialize_ast.py index 7fbafd024..e70b06f36 100644 --- a/pytype/pytd/serialize_ast.py +++ b/pytype/pytd/serialize_ast.py @@ -6,6 +6,7 @@ """ +from typing import TypeVar import msgspec from pytype import utils from pytype.pyi import parser @@ -13,6 +14,8 @@ from pytype.pytd import pytd_utils from pytype.pytd import visitors +_TSerializableAst = TypeVar("_TSerializableAst", bound="SerializableAst") + class UnrestorableDependencyError(Exception): """If a dependency can't be restored in the current state.""" @@ -21,11 +24,11 @@ class UnrestorableDependencyError(Exception): class FindClassTypesVisitor(visitors.Visitor): """Visitor to find class and function types.""" - def __init__(self): + def __init__(self) -> None: super().__init__() self.class_type_nodes = [] - def EnterClassType(self, n): + def EnterClassType(self, n) -> None: self.class_type_nodes.append(n) @@ -36,11 +39,11 @@ class UndoModuleAliasesVisitor(visitors.Visitor): names of modules, not whatever they've been aliased to in the current module. 
""" - def __init__(self): + def __init__(self) -> None: super().__init__() self._module_aliases = {} - def EnterTypeDeclUnit(self, node): + def EnterTypeDeclUnit(self, node) -> None: for alias in node.aliases: if isinstance(alias.type, pytd.Module): name = utils.strip_prefix(alias.name, f"{node.name}.") @@ -66,10 +69,10 @@ class ClearLookupCache(visitors.Visitor): (https://github.com/jcrist/msgspec/issues/199) """ - def LeaveClass(self, node): + def LeaveClass(self, node) -> None: node._name2item.clear() # pylint: disable=protected-access - def LeaveTypeDeclUnit(self, node): + def LeaveTypeDeclUnit(self, node) -> None: node._name2item.clear() # pylint: disable=protected-access @@ -101,7 +104,7 @@ class SerializableAst(msgspec.Struct): metadata: list[str] class_type_nodes: list[pytd.ClassType] | None = None - def __post_init__(self): + def __post_init__(self) -> None: # TODO(tsudol): I do not believe we actually use self.class_type_nodes for # anything besides filling in pointers. That is, it's ALWAYS the list of ALL # ClassType nodes in the AST. So the attribute doesn't need to exist. @@ -117,7 +120,7 @@ def __post_init__(self): else: self.class_type_nodes = indexer.class_type_nodes - def Replace(self, **kwargs): + def Replace(self: _TSerializableAst, **kwargs) -> _TSerializableAst: return msgspec.structs.replace(self, **kwargs) diff --git a/pytype/pytd/slots.py b/pytype/pytd/slots.py index 6f5b19062..3442b4970 100644 --- a/pytype/pytd/slots.py +++ b/pytype/pytd/slots.py @@ -5,7 +5,9 @@ mappings. 
""" +from collections.abc import Callable import dataclasses +from typing import Any TYPEOBJECT_PREFIX = "tp_" NUMBER_PREFIX = "nb_" @@ -397,11 +399,13 @@ class Slot: CMP_EXC_MATCH = 10 -CMP_ALWAYS_SUPPORTED = frozenset({CMP_EQ, CMP_NE, CMP_IS, CMP_IS_NOT}) +CMP_ALWAYS_SUPPORTED: frozenset[int] = frozenset( + {CMP_EQ, CMP_NE, CMP_IS, CMP_IS_NOT} +) EQ, NE, LT, LE, GT, GE = "==", "!=", "<", "<=", ">", ">=" -COMPARES = { +COMPARES: dict[str, Callable[[Any, Any], Any]] = { EQ: lambda x, y: x == y, NE: lambda x, y: x != y, LT: lambda x, y: x < y, @@ -411,12 +415,12 @@ class Slot: } -SYMBOL_MAPPING = { +SYMBOL_MAPPING: dict[str, str] = { slot.python_name: slot.symbol for slot in SLOTS if slot.symbol } -def _ReverseNameMapping(): +def _ReverseNameMapping() -> dict[str, str]: """__add__ -> __radd__, __mul__ -> __rmul__ etc.""" c_name_to_reverse = { slot.c_name: slot.python_name for slot in SLOTS if slot.index == 1 @@ -428,4 +432,4 @@ def _ReverseNameMapping(): } -REVERSE_NAME_MAPPING = _ReverseNameMapping() +REVERSE_NAME_MAPPING: dict[str, str] = _ReverseNameMapping() diff --git a/pytype/pytd/type_match.py b/pytype/pytd/type_match.py index e741f820d..b84a9f8e1 100644 --- a/pytype/pytd/type_match.py +++ b/pytype/pytd/type_match.py @@ -11,7 +11,7 @@ """ import logging -from typing import Optional, Union +from typing import Any, TypeVar, Optional, Union from pytype import utils from pytype.pytd import booleq @@ -21,7 +21,9 @@ from pytype.pytd import visitors from pytype.pytd.parse import node -log = logging.getLogger(__name__) +_T0 = TypeVar("_T0") + +log: logging.Logger = logging.getLogger(__name__) is_complete = escape.is_complete @@ -32,7 +34,7 @@ # Might not be needed anymore once pytd has builtin support for ~unknown. 
-def is_unknown(t): +def is_unknown(t) -> bool: """Return True if this is an ~unknown.""" if isinstance(t, (pytd.ClassType, pytd.NamedType, pytd.Class, StrictType)): return escape.is_unknown(t.name) @@ -42,7 +44,7 @@ def is_unknown(t): return False -def get_all_subclasses(asts): +def get_all_subclasses(asts) -> dict: """Compute a class->subclasses mapping. Args: @@ -80,14 +82,14 @@ class StrictType(node.Node): name: str - def __str__(self): + def __str__(self) -> str: return self.name class TypeMatch(pytd_utils.TypeMatcher): """Class for matching types against other types.""" - def __init__(self, direct_subclasses=None, any_also_is_bottom=True): + def __init__(self, direct_subclasses=None, any_also_is_bottom=True) -> None: """Construct. Args: @@ -178,7 +180,7 @@ def type_parameter( # not for matching of "known" types against each other. return StrictType(name) - def _get_parameters(self, t1, t2): + def _get_parameters(self, t1, t2) -> tuple[Any, Any]: if isinstance(t1, pytd.TupleType) and isinstance(t2, pytd.TupleType): # No change needed; the parameters will be compared element-wise. return t1.parameters, t2.parameters @@ -269,7 +271,7 @@ def match_Unknown_against_Generic( # pylint: disable=invalid-name ] return booleq.And([base_match] + params) - def match_Generic_against_Unknown(self, t1, t2, subst): # pylint: disable=invalid-name + def match_Generic_against_Unknown(self, t1, t2, subst) -> booleq.BooleanTerm: # pylint: disable=invalid-name # Note: This flips p1 and p2 above. return self.match_Unknown_against_Generic(t2, t1, subst) # pylint: disable=arguments-out-of-order @@ -286,7 +288,7 @@ def maybe_lookup_type_param(self, t, subst): t = subst[t] return t - def unclass(self, t): + def unclass(self, t: _T0) -> pytd.NamedType | _T0: """Prevent further subclass or superclass expansion for this type.""" if isinstance(t, pytd.ClassType): # When t.name and t.cls.name differ (e.g., int vs. 
builtins.int), the diff --git a/pytype/pytd/visitors.py b/pytype/pytd/visitors.py index d19a4305c..e27da9be2 100644 --- a/pytype/pytd/visitors.py +++ b/pytype/pytd/visitors.py @@ -1,11 +1,11 @@ """Visitor(s) for walking ASTs.""" import collections -from collections.abc import Callable +from collections.abc import Callable, Generator import itertools import logging import re -from typing import TypeVar, cast +from typing import Any, TypeVar, cast from pytype import datatypes from pytype import module_utils @@ -19,6 +19,8 @@ from pytype.pytd import pytd_visitors from pytype.pytd.parse import parser_constants # pylint: disable=g-importing-member +_T0 = TypeVar('_T0') + _N = TypeVar("_N", bound=pytd.Node) _T = TypeVar("_T", bound=pytd.Type) @@ -38,21 +40,31 @@ class LiteralValueError(Exception): class MissingModuleError(KeyError): - def __init__(self, module): + def __init__(self, module) -> None: self.module = module super().__init__(f"Unknown module {module}") # All public elements of pytd_visitors are aliased here so that we can maintain # the conceptually simpler illusion of having a single visitors module. 
-ALL_NODE_NAMES = base_visitor.ALL_NODE_NAMES -Visitor = base_visitor.Visitor -CanonicalOrderingVisitor = pytd_visitors.CanonicalOrderingVisitor -ClassTypeToNamedType = pytd_visitors.ClassTypeToNamedType -CollectTypeParameters = pytd_visitors.CollectTypeParameters -ExtractSuperClasses = pytd_visitors.ExtractSuperClasses -PrintVisitor = printer.PrintVisitor -RenameModuleVisitor = pytd_visitors.RenameModuleVisitor +ALL_NODE_NAMES: Any = base_visitor.ALL_NODE_NAMES +Visitor: type[base_visitor.Visitor] = base_visitor.Visitor +CanonicalOrderingVisitor: type[pytd_visitors.CanonicalOrderingVisitor] = ( + pytd_visitors.CanonicalOrderingVisitor +) +ClassTypeToNamedType: type[pytd_visitors.ClassTypeToNamedType] = ( + pytd_visitors.ClassTypeToNamedType +) +CollectTypeParameters: type[pytd_visitors.CollectTypeParameters] = ( + pytd_visitors.CollectTypeParameters +) +ExtractSuperClasses: type[pytd_visitors.ExtractSuperClasses] = ( + pytd_visitors.ExtractSuperClasses +) +PrintVisitor: type[printer.PrintVisitor] = printer.PrintVisitor +RenameModuleVisitor: type[pytd_visitors.RenameModuleVisitor] = ( + pytd_visitors.RenameModuleVisitor +) class FillInLocalPointers(Visitor): @@ -62,7 +74,7 @@ class FillInLocalPointers(Visitor): necessary because we introduce loops. """ - def __init__(self, lookup_map, fallback=None): + def __init__(self, lookup_map, fallback=None) -> None: """Create this visitor. You're expected to then pass this instance to node.Visit(). @@ -77,7 +89,7 @@ def __init__(self, lookup_map, fallback=None): lookup_map["*"] = fallback self._lookup_map = lookup_map - def _Lookup(self, node): + def _Lookup(self, node) -> Generator[tuple[str, Any], Any, None]: """Look up a node by name.""" if "." in node.name: modules_to_try = [] @@ -103,7 +115,7 @@ def _Lookup(self, node): else: yield prefix, item - def EnterClassType(self, node): + def EnterClassType(self, node) -> None: """Fills in a class type. 
Args: @@ -145,7 +157,7 @@ def EnterClassType(self, node): class _RemoveTypeParametersFromGenericAny(Visitor): """Adjusts GenericType nodes to handle base type changes.""" - unchecked_node_names = ("GenericType",) + unchecked_node_names: tuple[str] = ("GenericType",) def VisitGenericType(self, node): if isinstance(node.base_type, (pytd.AnythingType, pytd.Constant)): @@ -159,7 +171,7 @@ def VisitGenericType(self, node): class DefaceUnresolved(_RemoveTypeParametersFromGenericAny): """Replace all types not in a symbol table with AnythingType.""" - def __init__(self, lookup_list, do_not_log_prefix=None): + def __init__(self, lookup_list, do_not_log_prefix=None) -> None: """Create this visitor. Args: @@ -172,7 +184,7 @@ def __init__(self, lookup_list, do_not_log_prefix=None): self._lookup_list = lookup_list self._do_not_log_prefix = do_not_log_prefix - def VisitNamedType(self, node): + def VisitNamedType(self, node: _T0) -> pytd.AnythingType | _T0: """Do replacement on a pytd.NamedType.""" name = node.name for lookup in self._lookup_list: @@ -204,7 +216,7 @@ def VisitClassType(self, node): class NamedTypeToClassType(Visitor): """Change all NamedType objects to ClassType objects.""" - def VisitNamedType(self, node): + def VisitNamedType(self, node) -> pytd.ClassType: """Converts a named type to a class type, to be filled in later. 
Args: @@ -247,18 +259,18 @@ def LookupClasses(target, global_module=None, ignore_late_types=False): class VerifyLookup(Visitor): """Utility class for testing visitors.LookupClasses.""" - def __init__(self, ignore_late_types=False): + def __init__(self, ignore_late_types=False) -> None: super().__init__() self.ignore_late_types = ignore_late_types - def EnterLateType(self, node): + def EnterLateType(self, node) -> None: if not self.ignore_late_types: raise ValueError(f"Unresolved LateType: {node.name!r}") def EnterNamedType(self, node): raise ValueError(f"Unreplaced NamedType: {node.name!r}") - def EnterClassType(self, node): + def EnterClassType(self, node) -> None: if node.cls is None: raise ValueError(f"Unresolved class: {node.name!r}") @@ -273,23 +285,23 @@ class _ToTypeVisitor(Visitor): appropriate allow_constants and allow_functions values. """ - def __init__(self, allow_singletons): + def __init__(self, allow_singletons) -> None: super().__init__() self._in_alias = 0 self._in_literal = 0 self.allow_singletons = allow_singletons self.allow_functions = False - def EnterAlias(self, node): + def EnterAlias(self, node) -> None: self._in_alias += 1 - def LeaveAlias(self, _): + def LeaveAlias(self, _) -> None: self._in_alias -= 1 - def EnterLiteral(self, _): + def EnterLiteral(self, _) -> None: self._in_literal += 1 - def LeaveLiteral(self, _): + def LeaveLiteral(self, _) -> None: self._in_literal -= 1 def to_type(self, t): @@ -306,7 +318,7 @@ def to_type(self, t): class LookupBuiltins(_ToTypeVisitor): """Look up built-in NamedTypes and give them fully-qualified names.""" - def __init__(self, builtins, full_names=True, allow_singletons=False): + def __init__(self, builtins, full_names=True, allow_singletons=False) -> None: """Create this visitor. 
Args: @@ -318,11 +330,11 @@ def __init__(self, builtins, full_names=True, allow_singletons=False): self._builtins = builtins self._full_names = full_names - def EnterTypeDeclUnit(self, unit): + def EnterTypeDeclUnit(self, unit) -> None: self._current_unit = unit self._prefix = unit.name + "." if self._full_names else "" - def LeaveTypeDeclUnit(self, _): + def LeaveTypeDeclUnit(self, _) -> None: del self._current_unit del self._prefix @@ -374,7 +386,7 @@ def _MaybeSubstituteParametersInGenericType(node): class LookupExternalTypes(_RemoveTypeParametersFromGenericAny, _ToTypeVisitor): """Look up NamedType pointers using a symbol table.""" - def __init__(self, module_map, self_name=None, module_alias_map=None): + def __init__(self, module_map, self_name=None, module_alias_map=None) -> None: """Create this visitor. Args: @@ -426,21 +438,21 @@ def _ResolveUsingStarImport(self, module, name): return imported_alias return None - def EnterAlias(self, node): + def EnterAlias(self, node) -> None: super().EnterAlias(node) self._alias_names.append(node.name) - def LeaveAlias(self, node): + def LeaveAlias(self, node) -> None: super().LeaveAlias(node) self._alias_names.pop() - def EnterGenericType(self, _): + def EnterGenericType(self, _) -> None: self._in_generic_type += 1 - def LeaveGenericType(self, _): + def LeaveGenericType(self, _) -> None: self._in_generic_type -= 1 - def _LookupModuleRecursive(self, name): + def _LookupModuleRecursive(self, name) -> tuple[Any, Any]: module_name, cls_prefix = name, "" while module_name not in self._module_map and "." in module_name: module_name, class_name = module_name.rsplit(".", 1) @@ -450,7 +462,7 @@ def _LookupModuleRecursive(self, name): else: raise MissingModuleError(name) - def _IsLocalName(self, prefix): + def _IsLocalName(self, prefix) -> bool: if prefix == self.name: return True if not self._unit: @@ -585,7 +597,7 @@ def VisitGenericType(self, node): def _ModulePrefix(self): return self.name + "." 
if self.name else "" - def _ImportAll(self, module): + def _ImportAll(self, module) -> tuple[list[pytd.Alias], set]: """Get the new members that would result from a star import of the module. Args: @@ -643,7 +655,7 @@ def _ImportAll(self, module): aliases.append(pytd.Alias(new_name, t)) return aliases, getattrs - def _DiscardExistingNames(self, node, potential_members): + def _DiscardExistingNames(self, node, potential_members) -> list: new_members = [] for m in potential_members: if m.name not in node: @@ -677,7 +689,7 @@ def _EquivalentAliases(self, alias1, alias2) -> bool: return True return self._ResolveAlias(alias1) == self._ResolveAlias(alias2) - def _HandleDuplicates(self, new_aliases): + def _HandleDuplicates(self, new_aliases) -> list: """Handle duplicate module-level aliases. Aliases pointing to qualified names could be the result of importing the @@ -710,7 +722,7 @@ def _HandleDuplicates(self, new_aliases): ) return out - def EnterTypeDeclUnit(self, node): + def EnterTypeDeclUnit(self, node) -> None: self._unit = node def VisitTypeDeclUnit(self, node): @@ -760,25 +772,25 @@ def VisitTypeDeclUnit(self, node): class LookupLocalTypes(_RemoveTypeParametersFromGenericAny, _ToTypeVisitor): """Look up local identifiers. 
Must be called on a TypeDeclUnit.""" - def __init__(self, allow_singletons=False, toplevel=True): + def __init__(self, allow_singletons=False, toplevel=True) -> None: super().__init__(allow_singletons) self._toplevel = toplevel self.local_names = set() self.class_names = [] - def EnterTypeDeclUnit(self, unit): + def EnterTypeDeclUnit(self, unit) -> None: self.unit = unit - def LeaveTypeDeclUnit(self, _): + def LeaveTypeDeclUnit(self, _) -> None: del self.unit def _LookupItemRecursive(self, name: str) -> pytd.Node: return pytd.LookupItemRecursive(self.unit, name) - def EnterClass(self, node): + def EnterClass(self, node) -> None: self.class_names.append(node.name) - def LeaveClass(self, unused_node): + def LeaveClass(self, unused_node) -> None: self.class_names.pop() def _LookupScopedName(self, name: str) -> pytd.Node | None: @@ -819,7 +831,7 @@ def _LookupLocalName(self, node: pytd.Node) -> pytd.Node: raise SymbolLookupError(msg) return item - def _LookupLocalTypes(self, node): + def _LookupLocalTypes(self, node) -> tuple[Any, set[None]]: visitor = LookupLocalTypes(self.allow_singletons, toplevel=False) visitor.unit = self.unit return node.Visit(visitor), visitor.local_names @@ -891,7 +903,7 @@ def VisitNamedType(self, node): return self._LookupLocalTypes(resolved_node)[0] return resolved_node - def VisitClassType(self, t): + def VisitClassType(self, t: _T0) -> _T0: if not t.cls: if t.name == self.class_names[-1]: full_name = ".".join(self.class_names) @@ -915,7 +927,7 @@ class ReplaceTypesByName(Visitor): mapping. The two cases are not distinguished. """ - def __init__(self, mapping, record=None): + def __init__(self, mapping, record=None) -> None: """Initialize this visitor. 
Args: @@ -955,7 +967,7 @@ def __init__( self._matcher = matcher self._replacement = replacement - def VisitNamedType(self, node): + def VisitNamedType(self, node: _T0) -> pytd.Node | _T0: return self._replacement if self._matcher(node) else node def VisitClassType(self, node): @@ -980,7 +992,7 @@ def _Key(self, node): class ReplaceTypeParameters(Visitor): """Visitor for replacing type parameters with actual types.""" - def __init__(self, mapping): + def __init__(self, mapping) -> None: super().__init__() self.mapping = mapping @@ -988,7 +1000,7 @@ def VisitTypeParameter(self, p): return self.mapping[p] -def ClassAsType(cls): +def ClassAsType(cls) -> pytd.GenericType | pytd.NamedType: """Converts a pytd.Class to an instance of pytd.Type.""" params = tuple(item.type_param for item in cls.template) if not params: @@ -1011,27 +1023,27 @@ def f(self: A) first argument to just "self") """ - def __init__(self, force=False): + def __init__(self, force=False) -> None: super().__init__() self.class_types = [] # allow nested classes self.force = force self.method_kind = None - def EnterClass(self, cls): + def EnterClass(self, cls) -> None: self.class_types.append(ClassAsType(cls)) - def LeaveClass(self, unused_node): + def LeaveClass(self, unused_node) -> None: self.class_types.pop() - def EnterFunction(self, f): + def EnterFunction(self, f) -> None: if self.class_types: self.method_kind = f.kind - def LeaveFunction(self, f): + def LeaveFunction(self, f) -> None: if self.class_types: self.method_kind = None - def VisitClass(self, node): + def VisitClass(self, node: _T0) -> _T0: return node def VisitParameter(self, p): @@ -1077,24 +1089,24 @@ class ~unknown2: def f(x) -> Any """ - def __init__(self): + def __init__(self) -> None: super().__init__() self.parameter = None - def EnterParameter(self, p): + def EnterParameter(self, p) -> None: self.parameter = p - def LeaveParameter(self, p): + def LeaveParameter(self, p) -> None: assert self.parameter is p self.parameter = None - 
def VisitClassType(self, t): + def VisitClassType(self, t: _T0) -> pytd.AnythingType | _T0: if escape.is_unknown(t.name): return pytd.AnythingType() else: return t - def VisitNamedType(self, t): + def VisitNamedType(self, t: _T0) -> pytd.AnythingType | _T0: if escape.is_unknown(t.name): return pytd.AnythingType() else: @@ -1111,12 +1123,12 @@ def VisitTypeDeclUnit(self, u): class _CountUnknowns(Visitor): """Visitor for counting how often given unknowns occur in a type.""" - def __init__(self): + def __init__(self) -> None: super().__init__() self.counter = collections.Counter() self.position = {} - def EnterNamedType(self, t): + def EnterNamedType(self, t) -> None: _, is_unknown, suffix = t.name.partition(escape.UNKNOWN) if is_unknown: if suffix not in self.counter: @@ -1160,22 +1172,22 @@ def __enter__(self: _TFoo) -> _TFoo PREFIX = "_T" # Prefix for new type params - def __init__(self): + def __init__(self) -> None: super().__init__() self.parameter = None self.class_name = None self.function_name = None - def EnterClass(self, node): + def EnterClass(self, node) -> None: self.class_name = node.name - def LeaveClass(self, _): + def LeaveClass(self, _) -> None: self.class_name = None - def EnterFunction(self, node): + def EnterFunction(self, node) -> None: self.function_name = node.name - def LeaveFunction(self, _): + def LeaveFunction(self, _) -> None: self.function_name = None def _NeedsClassParam(self, sig): @@ -1242,7 +1254,7 @@ def VisitSignature(self, sig): sig = sig.Visit(ReplaceTypesByName(replacements)) return sig - def EnterTypeDeclUnit(self, _): + def EnterTypeDeclUnit(self, _) -> None: self.added_new_type_params = False def VisitTypeDeclUnit(self, unit): @@ -1257,11 +1269,11 @@ class VerifyVisitor(Visitor): _all_templates: set[pytd.Node] - def __init__(self): + def __init__(self) -> None: super().__init__() self._valid_param_name = re.compile(r"[a-zA-Z_]\w*$") - def _AssertNoDuplicates(self, node, attrs): + def _AssertNoDuplicates(self, node, attrs) -> 
None: """Checks that we don't have duplicate top-level names.""" get_set = lambda attr: {entry.name for entry in getattr(node, attr)} attr_to_set = {attr: get_set(attr) for attr in attrs} @@ -1276,13 +1288,13 @@ def _AssertNoDuplicates(self, node, attrs): f"Duplicate name(s) {list(both)} in both {a1} and {a2}" ) - def EnterTypeDeclUnit(self, node): + def EnterTypeDeclUnit(self, node) -> None: self._AssertNoDuplicates( node, ["constants", "type_params", "classes", "functions", "aliases"] ) self._all_templates = set() - def LeaveTypeDeclUnit(self, node): + def LeaveTypeDeclUnit(self, node) -> None: declared_type_params = {n.name for n in node.type_params} for t in self._all_templates: if t.name not in declared_type_params: @@ -1292,25 +1304,25 @@ def LeaveTypeDeclUnit(self, node): % t.name ) - def EnterClass(self, node): + def EnterClass(self, node) -> None: self._AssertNoDuplicates(node, ["methods", "constants"]) - def EnterFunction(self, node): + def EnterFunction(self, node) -> None: assert node.signatures, node - def EnterSignature(self, node): + def EnterSignature(self, node) -> None: assert isinstance(node.has_optional, bool), node - def EnterTemplateItem(self, node): + def EnterTemplateItem(self, node) -> None: self._all_templates.add(node) - def EnterParameter(self, node): + def EnterParameter(self, node) -> None: assert self._valid_param_name.match(node.name), node.name - def EnterCallableType(self, node): + def EnterCallableType(self, node) -> None: self.EnterGenericType(node) - def EnterGenericType(self, node): + def EnterGenericType(self, node) -> None: assert node.parameters, node @@ -1344,17 +1356,17 @@ def bar(x: baz.Foo) -> baz.Foo References to attributes of Any-typed constants will be resolved to Any. 
""" - def __init__(self): + def __init__(self) -> None: super().__init__() self.cls_stack = [] self.classes = None self.prefix = None self.name = None - def _ClassStackString(self): + def _ClassStackString(self) -> str: return ".".join(cls.name for cls in self.cls_stack) - def EnterTypeDeclUnit(self, node): + def EnterTypeDeclUnit(self, node) -> None: self.classes = {cls.name for cls in node.classes} # TODO(b/293451396): In certain weird cases, a local module named "typing" # may get mixed up with the stdlib typing module. We end up doing the right @@ -1367,10 +1379,10 @@ def EnterTypeDeclUnit(self, node): self.name = node.name self.prefix = node.name + "." - def EnterClass(self, cls): + def EnterClass(self, cls) -> None: self.cls_stack.append(cls) - def LeaveClass(self, cls): + def LeaveClass(self, cls) -> None: assert self.cls_stack[-1] is cls self.cls_stack.pop() @@ -1448,7 +1460,7 @@ def VisitModule(self, node): class RemoveNamePrefix(Visitor): """Visitor which removes the fully-qualified-names added by AddNamePrefix.""" - def __init__(self): + def __init__(self) -> None: super().__init__() self.cls_stack: list[pytd.Class] = [] self.classes: set[str] = set() @@ -1559,12 +1571,12 @@ class CollectDependencies(Visitor): Needs to be called on a TypeDeclUnit. """ - def __init__(self): + def __init__(self) -> None: super().__init__() self.dependencies = {} self.late_dependencies = {} - def _ProcessName(self, name, dependencies): + def _ProcessName(self, name, dependencies) -> None: """Retrieve a module name from a node name.""" module_name, dot, base_name = name.rpartition(".") if dot: @@ -1580,16 +1592,16 @@ def _ProcessName(self, name, dependencies): # and fail later on. 
logging.warning("Empty package name: %s", name) - def EnterClassType(self, node): + def EnterClassType(self, node) -> None: self._ProcessName(node.name, self.dependencies) - def EnterNamedType(self, node): + def EnterNamedType(self, node) -> None: self._ProcessName(node.name, self.dependencies) - def EnterLateType(self, node): + def EnterLateType(self, node) -> None: self._ProcessName(node.name, self.late_dependencies) - def EnterModule(self, node): + def EnterModule(self, node) -> None: # Most module nodes look like: # Module(name='foo_module.bar_module', module_name='bar_module'). # We don't care about these. Nodes that don't follow this pattern are @@ -1674,7 +1686,7 @@ class AdjustTypeParameters(Visitor): * Adds scopes to type parameters. """ - def __init__(self): + def __init__(self) -> None: super().__init__() self.class_typeparams = set() self.function_typeparams = None @@ -1685,7 +1697,7 @@ def __init__(self): self.all_typevariables = set() self.generic_level = 0 - def _GetTemplateItems(self, param): + def _GetTemplateItems(self, param) -> list: """Get a list of template items from a parameter.""" items = [] if isinstance(param, pytd.GenericType): @@ -1715,7 +1727,7 @@ def VisitTypeDeclUnit(self, node): ) return node.Replace(type_params=new_type_params) - def _CheckDuplicateNames(self, params, class_name): + def _CheckDuplicateNames(self, params, class_name) -> None: seen = set() for x in params: if x.name in seen: @@ -1725,7 +1737,7 @@ def _CheckDuplicateNames(self, params, class_name): ) seen.add(x.name) - def EnterClass(self, node): + def EnterClass(self, node) -> None: """Establish the template for the class.""" templates = [] generic_template = None @@ -1773,7 +1785,7 @@ def EnterClass(self, node): self.class_name = node.name - def LeaveClass(self, node): + def LeaveClass(self, node) -> None: del node for t in self.class_template[-1]: if t.name in self.class_typeparams: @@ -1794,11 +1806,11 @@ def VisitClass(self, node): node = 
node.Replace(template=tuple(template)) return node.Visit(AdjustSelf()).Visit(NamedTypeToClassType()) - def EnterSignature(self, unused_node): + def EnterSignature(self, unused_node) -> None: assert self.function_typeparams is None, self.function_typeparams self.function_typeparams = set() - def LeaveSignature(self, unused_node): + def LeaveSignature(self, unused_node) -> None: self.function_typeparams = None def _MaybeMutateSelf(self, sig): @@ -1835,43 +1847,43 @@ def VisitSignature(self, node): node.Replace(template=tuple(sorted(self.function_typeparams))) ) - def EnterFunction(self, node): + def EnterFunction(self, node) -> None: self.function_name = node.name - def LeaveFunction(self, unused_node): + def LeaveFunction(self, unused_node) -> None: self.function_name = None - def EnterConstant(self, node): + def EnterConstant(self, node) -> None: self.constant_name = node.name - def LeaveConstant(self, unused_node): + def LeaveConstant(self, unused_node) -> None: self.constant_name = None - def EnterGenericType(self, unused_node): + def EnterGenericType(self, unused_node) -> None: self.generic_level += 1 - def LeaveGenericType(self, unused_node): + def LeaveGenericType(self, unused_node) -> None: self.generic_level -= 1 - def EnterCallableType(self, node): + def EnterCallableType(self, node) -> None: self.EnterGenericType(node) - def LeaveCallableType(self, node): + def LeaveCallableType(self, node) -> None: self.LeaveGenericType(node) - def EnterTupleType(self, node): + def EnterTupleType(self, node) -> None: self.EnterGenericType(node) - def LeaveTupleType(self, node): + def LeaveTupleType(self, node) -> None: self.LeaveGenericType(node) - def EnterUnionType(self, node): + def EnterUnionType(self, node) -> None: self.EnterGenericType(node) - def LeaveUnionType(self, node): + def LeaveUnionType(self, node) -> None: self.LeaveGenericType(node) - def _GetFullName(self, name): + def _GetFullName(self, name) -> str: return ".".join(n for n in [self.class_name, name] 
if n) def _GetScope(self, name): @@ -1933,7 +1945,7 @@ class VerifyContainers(Visitor): ContainerError: If a problematic container definition is encountered. """ - def EnterGenericType(self, node): + def EnterGenericType(self, node) -> None: """Verify a pytd.GenericType.""" base_type = node.base_type if isinstance(base_type, pytd.LateType): @@ -1954,13 +1966,13 @@ def EnterGenericType(self, node): ) ) - def EnterCallableType(self, node): + def EnterCallableType(self, node) -> None: self.EnterGenericType(node) - def EnterTupleType(self, node): + def EnterTupleType(self, node) -> None: self.EnterGenericType(node) - def _GetGenericBasesLookupMap(self, node): + def _GetGenericBasesLookupMap(self, node) -> collections.defaultdict: """Get a lookup map for the generic bases of a class. Gets a map from a pytd.ClassType to the list of pytd.GenericType bases of @@ -1991,7 +2003,7 @@ def _GetGenericBasesLookupMap(self, node): bases.extend(reversed(base.cls.bases)) return mapping - def _UpdateParamToValuesMapping(self, mapping, param, value): + def _UpdateParamToValuesMapping(self, mapping, param, value) -> None: """Update the given mapping of parameter names to values.""" param_name = param.type_param.full_name if isinstance(value, pytd.TypeParameter): @@ -2010,7 +2022,7 @@ def _UpdateParamToValuesMapping(self, mapping, param, value): mapping[param_name] = set() mapping[param_name].add(value) - def _TypeCompatibilityCheck(self, type_params): + def _TypeCompatibilityCheck(self, type_params) -> bool: """Check if the types are compatible. It is used to handle the case: @@ -2041,7 +2053,7 @@ class C(B, Sequence[C]): pass prev = cur return True - def EnterClass(self, node): + def EnterClass(self, node) -> None: """Check for conflicting type parameter values in the class's bases.""" # Get the bases in MRO, since we need to know the order in which type # parameters are aliased or assigned values. @@ -2093,7 +2105,7 @@ class VerifyLiterals(Visitor): ClassType pointers are filled in. 
""" - def EnterLiteral(self, node): + def EnterLiteral(self, node) -> None: value = node.value if not isinstance(value, pytd.Constant): # This Literal does not hold an object, no need to check further. @@ -2149,18 +2161,18 @@ def EnterLiteral(self, node): class ClearClassPointers(Visitor): """Set .cls pointers to 'None'.""" - def EnterClassType(self, node): + def EnterClassType(self, node) -> None: node.cls = None class ReplaceModulesWithAny(_RemoveTypeParametersFromGenericAny): """Replace all references to modules in a list with AnythingType.""" - def __init__(self, module_list: list[str]): + def __init__(self, module_list: list[str]) -> None: super().__init__() self._any_modules = module_list - def VisitNamedType(self, n): + def VisitNamedType(self, n: _T0) -> pytd.AnythingType | _T0: if any(n.name.startswith(module) for module in self._any_modules): return pytd.AnythingType() return n @@ -2174,14 +2186,14 @@ def VisitClassType(self, n): class ReplaceUnionsWithAny(Visitor): - def VisitUnionType(self, _): + def VisitUnionType(self, _) -> pytd.AnythingType: return pytd.AnythingType() class ClassTypeToLateType(Visitor): """Convert ClassType to LateType.""" - def __init__(self, ignore): + def __init__(self, ignore) -> None: """Initialize the visitor. 
Args: @@ -2203,7 +2215,7 @@ def VisitClassType(self, n): class LateTypeToClassType(Visitor): """Convert LateType to (unresolved) ClassType.""" - def VisitLateType(self, t): + def VisitLateType(self, t) -> pytd.ClassType: return pytd.ClassType(t.name, None) diff --git a/pytype/pytype_source_utils.py b/pytype/pytype_source_utils.py index 96eb2baf2..8d9975685 100644 --- a/pytype/pytype_source_utils.py +++ b/pytype/pytype_source_utils.py @@ -6,6 +6,7 @@ import os import re +from typing import Any, Generator from pytype.platform_utils import path_utils @@ -14,7 +15,7 @@ class NoSuchDirectory(Exception): # pylint: disable=g-bad-exception-name pass -def _pytype_source_dir(): +def _pytype_source_dir() -> str: """The base directory of the pytype source tree.""" res = path_utils.dirname(__file__) if path_utils.basename(res) == "__pycache__": @@ -76,7 +77,7 @@ def _load_data_file(filename, text): return fi.read() -def list_files(basedir): +def list_files(basedir) -> Generator[str, Any, None]: """List files in the directory rooted at |basedir|.""" if not path_utils.isdir(basedir): raise NoSuchDirectory(basedir) @@ -91,7 +92,7 @@ def list_files(basedir): yield filename -def list_pytype_files(suffix): +def list_pytype_files(suffix) -> Generator[Any, Any, None]: """Recursively get the contents of a directory in the pytype installation. This reports files in said directory as well as all subdirectories of it. 
diff --git a/pytype/rewrite/abstract/abstract.py b/pytype/rewrite/abstract/abstract.py index d5b163372..e13f58c7a 100644 --- a/pytype/rewrite/abstract/abstract.py +++ b/pytype/rewrite/abstract/abstract.py @@ -1,5 +1,6 @@ """Abstract representations of Python values.""" +from typing import TypeVar from pytype.rewrite.abstract import base as _base from pytype.rewrite.abstract import classes as _classes from pytype.rewrite.abstract import containers as _containers @@ -7,38 +8,42 @@ from pytype.rewrite.abstract import internal as _internal from pytype.rewrite.abstract import utils as _utils -BaseValue = _base.BaseValue -ContextType = _base.ContextType -PythonConstant = _base.PythonConstant -Singleton = _base.Singleton -Union = _base.Union - -SimpleClass = _classes.SimpleClass -BaseInstance = _classes.BaseInstance -FrozenInstance = _classes.FrozenInstance -InterpreterClass = _classes.InterpreterClass -Module = _classes.Module -MutableInstance = _classes.MutableInstance - -Args = _functions.Args -BaseFunction = _functions.BaseFunction -BoundFunction = _functions.BoundFunction -FrameType = _functions.FrameType -InterpreterFunction = _functions.InterpreterFunction -MappedArgs = _functions.MappedArgs -PytdFunction = _functions.PytdFunction -Signature = _functions.Signature -SimpleFunction = _functions.SimpleFunction -SimpleReturn = _functions.SimpleReturn - -Dict = _containers.Dict -List = _containers.List -Set = _containers.Set -Tuple = _containers.Tuple - -FunctionArgDict = _internal.FunctionArgDict -FunctionArgTuple = _internal.FunctionArgTuple -Splat = _internal.Splat +_T = TypeVar('_T') + +BaseValue: type[_base.BaseValue] = _base.BaseValue +ContextType: type[_base.ContextType] = _base.ContextType +PythonConstant: type[_base.PythonConstant] = _base.PythonConstant +Singleton: type[_base.Singleton] = _base.Singleton +Union: type[_base.Union] = _base.Union + +SimpleClass: type[_classes.SimpleClass] = _classes.SimpleClass +BaseInstance: type[_classes.BaseInstance] = 
_classes.BaseInstance +FrozenInstance: type[_classes.FrozenInstance] = _classes.FrozenInstance +InterpreterClass: type[_classes.InterpreterClass] = _classes.InterpreterClass +Module: type[_classes.Module] = _classes.Module +MutableInstance: type[_classes.MutableInstance] = _classes.MutableInstance + +Args: type[_functions.Args] = _functions.Args +BaseFunction: type[_functions.BaseFunction] = _functions.BaseFunction +BoundFunction: type[_functions.BoundFunction] = _functions.BoundFunction +FrameType: type[_functions.FrameType] = _functions.FrameType +InterpreterFunction: type[_functions.InterpreterFunction] = ( + _functions.InterpreterFunction +) +MappedArgs: type[_functions.MappedArgs] = _functions.MappedArgs +PytdFunction: type[_functions.PytdFunction] = _functions.PytdFunction +Signature: type[_functions.Signature] = _functions.Signature +SimpleFunction: type[_functions.SimpleFunction] = _functions.SimpleFunction +SimpleReturn: type[_functions.SimpleReturn] = _functions.SimpleReturn + +Dict: type[_containers.Dict] = _containers.Dict +List: type[_containers.List] = _containers.List +Set: type[_containers.Set] = _containers.Set +Tuple: type[_containers.Tuple] = _containers.Tuple + +FunctionArgDict: type[_internal.FunctionArgDict] = _internal.FunctionArgDict +FunctionArgTuple: type[_internal.FunctionArgTuple] = _internal.FunctionArgTuple +Splat: type[_internal.Splat] = _internal.Splat get_atomic_constant = _utils.get_atomic_constant join_values = _utils.join_values diff --git a/pytype/rewrite/abstract/base.py b/pytype/rewrite/abstract/base.py index 5d347788d..408ad16ab 100644 --- a/pytype/rewrite/abstract/base.py +++ b/pytype/rewrite/abstract/base.py @@ -12,6 +12,11 @@ from pytype.types import types from typing_extensions import Self +_SelfBaseValue = TypeVar('_SelfBaseValue', bound=types.BaseValue) +_TBaseValue = TypeVar('_TBaseValue', bound=types.BaseValue) +_TSingleton = TypeVar('_TSingleton', bound='Singleton') +_TUnion = TypeVar('_TUnion', bound='Union') + _T = 
TypeVar('_T') @@ -36,7 +41,7 @@ class BaseValue(types.BaseValue, abc.ABC): # to define it. name = '' - def __init__(self, ctx: ContextType): + def __init__(self, ctx: ContextType) -> None: self._ctx = ctx @abc.abstractmethod @@ -58,10 +63,10 @@ def _attrs(self) -> tuple[Any, ...]: def full_name(self): return self.name - def __eq__(self, other): + def __eq__(self, other) -> bool: return self.__class__ == other.__class__ and self._attrs == other._attrs - def __hash__(self): + def __hash__(self) -> int: return hash((self.__class__, self._ctx) + self._attrs) def to_variable(self, name: str | None = None) -> variables.Variable[Self]: @@ -97,14 +102,17 @@ class PythonConstant(BaseValue, Generic[_T]): """ def __init__( - self, ctx: ContextType, constant: _T, allow_direct_instantiation=False): + self, ctx: ContextType, constant: _T, allow_direct_instantiation=False + ) -> None: if self.__class__ is PythonConstant and not allow_direct_instantiation: - raise ValueError('Do not instantiate PythonConstant directly. Use ' - 'ctx.consts[constant] instead.') + raise ValueError( + 'Do not instantiate PythonConstant directly. Use ' + 'ctx.consts[constant] instead.' + ) super().__init__(ctx) self.constant = constant - def __repr__(self): + def __repr__(self) -> str: return f'PythonConstant({self.constant!r})' @property @@ -122,7 +130,7 @@ class Singleton(BaseValue): name: str - def __init__(self, ctx, name, allow_direct_instantiation=False): + def __init__(self, ctx, name, allow_direct_instantiation=False) -> None: if self.__class__ is Singleton and not allow_direct_instantiation: raise ValueError('Do not instantiate Singleton directly. 
Use ' 'ctx.consts.singles[name] instead.') @@ -146,7 +154,7 @@ def get_attribute(self, name: str) -> 'Singleton': class Union(BaseValue): """Union of values.""" - def __init__(self, ctx: ContextType, options: Sequence[BaseValue]): + def __init__(self, ctx: ContextType, options: Sequence[BaseValue]) -> None: super().__init__(ctx) assert len(options) > 1 flattened_options = [] @@ -157,14 +165,14 @@ def __init__(self, ctx: ContextType, options: Sequence[BaseValue]): flattened_options.append(o) self.options = tuple(utils.unique_list(flattened_options)) - def __repr__(self): + def __repr__(self) -> str: return ' | '.join(repr(o) for o in self.options) @property def _attrs(self): return (frozenset(self.options),) - def instantiate(self): + def instantiate(self: _TUnion) -> _TUnion: return Union(self._ctx, tuple(o.instantiate() for o in self.options)) diff --git a/pytype/rewrite/abstract/classes.py b/pytype/rewrite/abstract/classes.py index 5babfb06d..2d6ea6f79 100644 --- a/pytype/rewrite/abstract/classes.py +++ b/pytype/rewrite/abstract/classes.py @@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence import dataclasses import logging -from typing import Optional, Protocol +from typing import Optional, Protocol, TypeVar from pytype import datatypes from pytype.pytd import mro as mro_lib @@ -12,7 +12,10 @@ from pytype.rewrite.abstract import functions as functions_lib from pytype.types import types -log = logging.getLogger(__name__) + +_TSimpleClass = TypeVar('_TSimpleClass', bound='SimpleClass') + +log: logging.Logger = logging.getLogger(__name__) class _HasMembers(Protocol): @@ -25,7 +28,7 @@ class ClassCallReturn: instance: 'MutableInstance' - def get_return_value(self): + def get_return_value(self) -> 'MutableInstance': return self.instance @@ -68,7 +71,7 @@ def __init__( # instance methods called on an instance immediately after creation self.initializers = ['__init__'] - def __repr__(self): + def __repr__(self) -> str: return f'SimpleClass({self.full_name})' 
@property @@ -148,7 +151,7 @@ def mro(self) -> Sequence['SimpleClass']: self._mro = mro = mro_lib.MROMerge(mro_bases) return mro - def set_type_parameters(self, params): + def set_type_parameters(self: _TSimpleClass, params) -> _TSimpleClass: # A dummy implementation to let type annotations with parameters not crash. del params # not implemented yet # We eventually want to return a new class with the type parameters set @@ -174,7 +177,7 @@ def __init__( self.functions = functions self.classes = classes - def __repr__(self): + def __repr__(self) -> str: return f'InterpreterClass({self.name})' @property @@ -213,7 +216,7 @@ class MutableInstance(BaseInstance): def __init__(self, ctx: base.ContextType, cls: SimpleClass): super().__init__(ctx, cls, {}) - def __repr__(self): + def __repr__(self) -> str: return f'MutableInstance({self.cls.name})' @property @@ -241,7 +244,7 @@ def __init__(self, ctx: base.ContextType, instance: MutableInstance): super().__init__( ctx, instance.cls, datatypes.immutabledict(instance.members)) - def __repr__(self): + def __repr__(self) -> str: return f'FrozenInstance({self.cls.name})' @property @@ -264,7 +267,7 @@ def __init__(self, ctx: base.ContextType, name: str): super().__init__(ctx, cls, members={}) self.name = name - def __repr__(self): + def __repr__(self) -> str: return f'Module({self.name})' @property diff --git a/pytype/rewrite/abstract/containers.py b/pytype/rewrite/abstract/containers.py index ee6b1922f..228eb1e83 100644 --- a/pytype/rewrite/abstract/containers.py +++ b/pytype/rewrite/abstract/containers.py @@ -1,12 +1,17 @@ """Abstract representations of builtin containers.""" import logging +from typing import TypeVar from pytype.rewrite.abstract import base from pytype.rewrite.abstract import internal from pytype.rewrite.abstract import utils -log = logging.getLogger(__name__) +_TDict = TypeVar('_TDict', bound='Dict') +_TList = TypeVar('_TList', bound='List') +_TSet = TypeVar('_TSet', bound='Set') + +log: logging.Logger = 
logging.getLogger(__name__) # Type aliases _Var = base.AbstractVariableType @@ -19,7 +24,7 @@ def __init__(self, ctx: base.ContextType, constant: list[_Var]): assert isinstance(constant, list), constant super().__init__(ctx, constant) - def __repr__(self): + def __repr__(self) -> str: return f'List({self.constant!r})' def append(self, var: _Var) -> 'List': @@ -41,7 +46,7 @@ def __init__( assert isinstance(constant, dict), constant super().__init__(ctx, constant) - def __repr__(self): + def __repr__(self) -> str: return f'Dict({self.constant!r})' @classmethod @@ -63,7 +68,7 @@ def update(self, val: 'Dict') -> base.BaseValue: def to_function_arg_dict(self) -> internal.FunctionArgDict: new_const = { - utils.get_atomic_constant(k, str): v + utils.get_atomic_constant(k, str): v # pytype: disable=wrong-arg-types for k, v in self.constant.items() } return internal.FunctionArgDict(self._ctx, new_const) @@ -76,7 +81,7 @@ def __init__(self, ctx: base.ContextType, constant: set[_Var]): assert isinstance(constant, set), constant super().__init__(ctx, constant) - def __repr__(self): + def __repr__(self) -> str: return f'Set({self.constant!r})' def add(self, val: _Var) -> 'Set': @@ -90,5 +95,5 @@ def __init__(self, ctx: base.ContextType, constant: tuple[_Var, ...]): assert isinstance(constant, tuple), constant super().__init__(ctx, constant) - def __repr__(self): + def __repr__(self) -> str: return f'Tuple({self.constant!r})' diff --git a/pytype/rewrite/abstract/functions.py b/pytype/rewrite/abstract/functions.py index 709a63da5..a22428d9a 100644 --- a/pytype/rewrite/abstract/functions.py +++ b/pytype/rewrite/abstract/functions.py @@ -28,8 +28,13 @@ from pytype.rewrite.abstract import base from pytype.rewrite.abstract import containers from pytype.rewrite.abstract import internal +from pytype.rewrite.flow import variables -log = logging.getLogger(__name__) + +_TFrameType = TypeVar('_TFrameType', bound='FrameType') +_TSignature = TypeVar('_TSignature', bound=pytd.Signature) + 
+log: logging.Logger = logging.getLogger(__name__) _Var = base.AbstractVariableType _ArgDict = dict[str, _Var] @@ -60,7 +65,7 @@ def load_attr(self, target_var: _Var, attr_name: str) -> _Var: ... _FrameT = TypeVar('_FrameT', bound=FrameType) -def _unpack_splats(elts): +def _unpack_splats(elts) -> tuple: """Unpack any concrete splats and splice them into the sequence.""" ret = [] for e in elts: @@ -106,7 +111,7 @@ def __init__(self, ctx: base.ContextType, args: Args, sig: 'Signature'): self.sig = sig self.argdict: _ArgDict = {} - def _expand_positional_args(self): + def _expand_positional_args(self) -> None: """Unpack concrete splats in posargs.""" new_posargs = _unpack_splats(self.args.posargs) self.args = dataclasses.replace(self.args, posargs=new_posargs) @@ -219,7 +224,7 @@ def _unpack_starargs(self) -> tuple[tuple[_Var, ...], _Var | None]: # We have **kwargs but no *args in the invocation return tuple(pre), None - def _map_posargs(self): + def _map_posargs(self) -> None: posargs, starargs = self._unpack_starargs() argdict = dict(zip(self.sig.param_names, posargs)) self.argdict.update(argdict) @@ -229,7 +234,9 @@ def _map_posargs(self): starargs = self._ctx.consts.Any.to_variable() self.argdict[self.sig.varargs_name] = starargs - def _unpack_starstarargs(self): + def _unpack_starstarargs( + self, + ) -> tuple[dict, variables.Variable[internal.FunctionArgDict]]: """Adjust **args and kwargs based on function signature.""" starstarargs_var = self.args.starstarargs if starstarargs_var is None: @@ -260,7 +267,7 @@ def _unpack_starstarargs(self): self._ctx, starstarargs_dict, starstarargs.indefinite) return kwargs_dict, new_starstarargs.to_variable() - def _map_kwargs(self): + def _map_kwargs(self) -> None: kwargs, starstarargs = self._unpack_starstarargs() # Copy kwargs into argdict self.argdict.update(kwargs) @@ -268,7 +275,7 @@ def _map_kwargs(self): if self.sig.kwargs_name: self.argdict[self.sig.kwargs_name] = starstarargs - def map_args(self): + def 
map_args(self) -> dict: self._expand_positional_args() self._map_kwargs() self._map_posargs() @@ -296,7 +303,7 @@ class SimpleReturn: def __init__(self, return_value: base.BaseValue): self._return_value = return_value - def get_return_value(self): + def get_return_value(self) -> base.BaseValue: return self._return_value @@ -414,7 +421,7 @@ def from_pytd( annotations=annotations, ) - def __repr__(self): + def __repr__(self) -> str: pp = self._ctx.errorlog.pretty_printer def fmt(param_name): @@ -514,7 +521,7 @@ def __init__( self._signatures = signatures self.module = module - def __repr__(self): + def __repr__(self) -> str: return f'SimpleFunction({self.full_name})' @property @@ -589,7 +596,7 @@ def __init__( self._parent_frame = parent_frame self._call_cache = {} - def __repr__(self): + def __repr__(self) -> str: return f'InterpreterFunction({self.name})' @property @@ -639,7 +646,7 @@ def __init__( self.callself = callself self.underlying = underlying - def __repr__(self): + def __repr__(self) -> str: return f'BoundFunction({self.callself!r}, {self.underlying!r})' @property diff --git a/pytype/rewrite/abstract/internal.py b/pytype/rewrite/abstract/internal.py index 27ec36311..503d0f764 100644 --- a/pytype/rewrite/abstract/internal.py +++ b/pytype/rewrite/abstract/internal.py @@ -1,11 +1,13 @@ """Abstract types used internally by pytype.""" import collections +from typing import TypeVar import immutabledict - from pytype.rewrite.abstract import base +_TSplat = TypeVar("_TSplat", bound="Splat") + # Type aliases _Var = base.AbstractVariableType @@ -25,7 +27,7 @@ def __init__( self.constant = constant self.indefinite = indefinite - def __repr__(self): + def __repr__(self) -> str: indef = "+" if self.indefinite else "" return f"FunctionArgTuple({indef}{self.constant!r})" @@ -55,7 +57,7 @@ def _check_keys(self, constant: dict[str, _Var]): if not all(isinstance(k, str) for k in constant): raise ValueError("Passing a non-string key to a function arg dict") - def 
__repr__(self): + def __repr__(self) -> str: indef = "+" if self.indefinite else "" return f"FunctionArgDict({indef}{self.constant!r})" @@ -87,7 +89,7 @@ def get_concrete_iterable(self): else: raise ValueError("Not a concrete iterable") - def __repr__(self): + def __repr__(self) -> str: return f"splat({self.iterable!r})" @property diff --git a/pytype/rewrite/abstract/utils.py b/pytype/rewrite/abstract/utils.py index 6b4e380a0..99e2405be 100644 --- a/pytype/rewrite/abstract/utils.py +++ b/pytype/rewrite/abstract/utils.py @@ -1,9 +1,10 @@ """Utilities for working with abstract values.""" from collections.abc import Sequence -from typing import Any, TypeVar, get_origin, overload +from typing import Any, get_origin, overload, TypeVar from pytype.rewrite.abstract import base +from pytype.rewrite.flow import variables _Var = base.AbstractVariableType @@ -20,18 +21,22 @@ def get_atomic_constant(var: _Var, typ: None = ...) -> Any: ... -def get_atomic_constant(var, typ=None): +def get_atomic_constant( + var: variables.Variable[base.BaseValue], typ: None = None +): value = var.get_atomic_value(base.PythonConstant) constant = value.constant if typ and not isinstance(constant, (runtime_type := get_origin(typ) or typ)): raise ValueError( f'Wrong constant type for {var.display_name()}: expected ' - f'{runtime_type.__name__}, got {constant.__class__.__name__}') + f'{runtime_type.__name__}, got {constant.__class__.__name__}' + ) return constant def join_values( - ctx: base.ContextType, values: Sequence[base.BaseValue]) -> base.BaseValue: + ctx: base.ContextType, values: Sequence[base.BaseValue] +) -> base.BaseValue: if len(values) > 1: return base.Union(ctx, values) elif values: diff --git a/pytype/rewrite/analyze.py b/pytype/rewrite/analyze.py index 564f7d7d4..3fb13ad49 100644 --- a/pytype/rewrite/analyze.py +++ b/pytype/rewrite/analyze.py @@ -13,7 +13,7 @@ _INIT_MAXIMUM_DEPTH = 4 # during module loading _MAXIMUM_DEPTH = 3 # during analysis of function bodies -log = 
logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) @dataclasses.dataclass diff --git a/pytype/rewrite/convert.py b/pytype/rewrite/convert.py index 246647f6f..6b0cdc6b3 100644 --- a/pytype/rewrite/convert.py +++ b/pytype/rewrite/convert.py @@ -8,7 +8,7 @@ class _Cache: - def __init__(self): + def __init__(self) -> None: self.classes = {} self.funcs = {} self.types = {} diff --git a/pytype/rewrite/flow/conditions.py b/pytype/rewrite/flow/conditions.py index a4573a9dd..69d094886 100644 --- a/pytype/rewrite/flow/conditions.py +++ b/pytype/rewrite/flow/conditions.py @@ -1,7 +1,9 @@ """Variables, bindings, and conditions.""" import dataclasses -from typing import ClassVar +from typing import TypeVar, ClassVar + +_T = TypeVar('_T') _frozen_dataclass = dataclasses.dataclass(frozen=True) @@ -14,19 +16,19 @@ class Condition: @_frozen_dataclass class _True(Condition): - def __repr__(self): + def __repr__(self) -> str: return 'TRUE' @_frozen_dataclass class _False(Condition): - def __repr__(self): + def __repr__(self) -> str: return 'FALSE' -TRUE = _True() -FALSE = _False() +TRUE: _True = _True() +FALSE: _False = _False() @_frozen_dataclass @@ -35,7 +37,7 @@ class _Not(Condition): condition: Condition - def __repr__(self): + def __repr__(self) -> str: return f'not {self.condition}' @classmethod @@ -77,7 +79,7 @@ def make(cls, *args: Condition) -> Condition: return conditions.pop() return cls(frozenset(conditions)) - def __repr__(self): + def __repr__(self) -> str: conditions = [] for c in self.conditions: if isinstance(c, _Composite): diff --git a/pytype/rewrite/flow/frame_base.py b/pytype/rewrite/flow/frame_base.py index bb62e107a..406c13b12 100644 --- a/pytype/rewrite/flow/frame_base.py +++ b/pytype/rewrite/flow/frame_base.py @@ -18,7 +18,7 @@ _T = TypeVar('_T') -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) _FINAL = -1 diff --git a/pytype/rewrite/flow/state.py b/pytype/rewrite/flow/state.py index 
dd47b7819..24cdddf16 100644 --- a/pytype/rewrite/flow/state.py +++ b/pytype/rewrite/flow/state.py @@ -28,7 +28,7 @@ def __init__( else: self._locals_with_block_condition = locals_with_block_condition - def __repr__(self): + def __repr__(self) -> str: return (f'BlockState(locals={self._locals}, condition={self._condition}, ' f'locals_with_block_condition={self._locals_with_block_condition})') diff --git a/pytype/rewrite/flow/variables.py b/pytype/rewrite/flow/variables.py index 59865c211..4ea7a97d6 100644 --- a/pytype/rewrite/flow/variables.py +++ b/pytype/rewrite/flow/variables.py @@ -17,7 +17,7 @@ class Binding(Generic[_T]): value: _T condition: conditions.Condition = conditions.TRUE - def __repr__(self): + def __repr__(self) -> str: if self.condition is conditions.TRUE: return f'Bind[{self.value}]' return f'Bind[{self.value} if {self.condition}]' @@ -61,17 +61,19 @@ def get_atomic_value(self, typ: type[_T2]) -> _T2: def get_atomic_value(self, typ: None = ...) -> _T: ... - def get_atomic_value(self, typ=None): + def get_atomic_value(self, typ: None = None) -> '_T': """Gets this variable's value if there's exactly one, errors otherwise.""" if not self.is_atomic(): desc = 'many' if len(self.bindings) > 1 else 'few' raise ValueError( - f'Too {desc} bindings for {self.display_name()}: {self.bindings}') + f'Too {desc} bindings for {self.display_name()}: {self.bindings}' + ) value = self.bindings[0].value if typ and not isinstance(value, (runtime_type := get_origin(typ) or typ)): raise ValueError( f'Wrong type for {self.display_name()}: expected ' - f'{runtime_type.__name__}, got {value.__class__.__name__}') + f'{runtime_type.__name__}, got {value.__class__.__name__}' + ) return value def is_atomic(self, typ: type[_T] | None = None) -> bool: @@ -100,7 +102,7 @@ def with_value(self, value: _T2) -> 'Variable[_T2]': new_binding = dataclasses.replace(self.bindings[0], value=value) return dataclasses.replace(self, bindings=(new_binding,)) - def __repr__(self): + def 
__repr__(self) -> str: bindings = ' | '.join(repr(b) for b in self.bindings) if self.name: return f'Var[{self.name} -> {bindings}]' diff --git a/pytype/rewrite/frame.py b/pytype/rewrite/frame.py index 98c36d7d8..1e18fea3f 100644 --- a/pytype/rewrite/frame.py +++ b/pytype/rewrite/frame.py @@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence import logging -from typing import Any, Optional +from typing import Any, Optional, TypeVar from pycnite import marshal as pyc_marshal from pytype import datatypes @@ -16,7 +16,9 @@ from pytype.rewrite.flow import frame_base from pytype.rewrite.flow import variables -log = logging.getLogger(__name__) +_TFrame = TypeVar('_TFrame', bound='Frame') + +log: logging.Logger = logging.getLogger(__name__) # Type aliases _Var = variables.Variable[abstract.BaseValue] @@ -24,13 +26,13 @@ _FrameFunction = abstract.InterpreterFunction['Frame'] # This enum will be used frequently, so alias it -_Flags = pyc_marshal.Flags +_Flags: type[pyc_marshal.Flags] = pyc_marshal.Flags class _ShadowedNonlocals: """Tracks shadowed nonlocal names.""" - def __init__(self): + def __init__(self) -> None: self._enclosing: set[str] = set() self._globals: set[str] = set() @@ -40,10 +42,10 @@ def add_enclosing(self, name: str) -> None: def add_global(self, name: str) -> None: self._globals.add(name) - def has_enclosing(self, name: str): + def has_enclosing(self, name: str) -> bool: return name in self._enclosing - def has_global(self, name: str): + def has_global(self, name: str) -> bool: return name in self._globals def get_global_names(self) -> frozenset[str]: @@ -99,7 +101,7 @@ def __init__( # Handler for function calls. 
self._call_helper = function_call_helper.FunctionCallHelper(ctx, self) - def __repr__(self): + def __repr__(self) -> str: return f'Frame({self.name})' @classmethod @@ -160,7 +162,7 @@ def run(self) -> None: for name, var in self._final_locals.items() }) - def _log_stack(self): + def _log_stack(self) -> None: log.debug('stack: %r', self._stack) def store_local(self, name: str, var: _Var) -> None: @@ -367,7 +369,7 @@ def _load_method( self.load_attr(instance_var, method_name), ) - def _pop_jump_if_false(self, opcode): + def _pop_jump_if_false(self, opcode) -> None: unused_var = self._stack.pop() # TODO(b/324465215): Construct the real conditions for this jump. jump_state = self._current_state.with_condition(conditions.Condition()) @@ -384,18 +386,18 @@ def _replace_atomic_stack_value( # --------------------------------------------------------------- # Opcodes with no typing effects - def byte_NOP(self, opcode): + def byte_NOP(self, opcode) -> None: del opcode # unused - def byte_PRINT_EXPR(self, opcode): + def byte_PRINT_EXPR(self, opcode) -> None: del opcode # unused self._stack.pop_and_discard() - def byte_PRECALL(self, opcode): + def byte_PRECALL(self, opcode) -> None: # Internal cpython use del opcode # unused - def byte_RESUME(self, opcode): + def byte_RESUME(self, opcode) -> None: # Internal cpython use del opcode # unused @@ -413,28 +415,28 @@ def _get_const(self, oparg): val = self._ctx.consts[const] return val.to_variable() - def byte_LOAD_CONST(self, opcode): + def byte_LOAD_CONST(self, opcode) -> None: self._stack.push(self._get_const(opcode.arg)) - def byte_RETURN_VALUE(self, opcode): + def byte_RETURN_VALUE(self, opcode) -> None: self._returns.append(self._stack.pop()) - def byte_RETURN_CONST(self, opcode): + def byte_RETURN_CONST(self, opcode) -> None: self._returns.append(self._get_const(opcode.arg)) - def byte_STORE_NAME(self, opcode): + def byte_STORE_NAME(self, opcode) -> None: self.store_local(opcode.argval, self._stack.pop()) - def 
byte_STORE_FAST(self, opcode): + def byte_STORE_FAST(self, opcode) -> None: self.store_local(opcode.argval, self._stack.pop()) - def byte_STORE_GLOBAL(self, opcode): + def byte_STORE_GLOBAL(self, opcode) -> None: self.store_global(opcode.argval, self._stack.pop()) - def byte_STORE_DEREF(self, opcode): + def byte_STORE_DEREF(self, opcode) -> None: self.store_deref(opcode.argval, self._stack.pop()) - def byte_STORE_ATTR(self, opcode): + def byte_STORE_ATTR(self, opcode) -> None: attr_name = opcode.argval attr, target = self._stack.popn(2) if not target.name: @@ -442,25 +444,25 @@ def byte_STORE_ATTR(self, opcode): full_name = f'{target.name}.{attr_name}' self.store_local(full_name, attr) - def _unpack_function_annotations(self, packed_annot): + def _unpack_function_annotations(self, packed_annot) -> dict: if self._code.python_version >= (3, 10): # In Python 3.10+, packed_annot is a tuple of variables: # (param_name1, param_type1, param_name2, param_type2, ...) - annot_seq = abstract.get_atomic_constant(packed_annot, tuple) + annot_seq = abstract.get_atomic_constant(packed_annot, tuple) # pytype: disable=wrong-arg-types double_num_annots = len(annot_seq) assert not double_num_annots % 2 annot = {} for i in range(double_num_annots // 2): - name = abstract.get_atomic_constant(annot_seq[i * 2], str) + name = abstract.get_atomic_constant(annot_seq[i * 2], str) # pytype: disable=wrong-arg-types annot[name] = annot_seq[i * 2 + 1] else: # Pre-3.10, packed_annot was a name->param_type dictionary. 
- annot = abstract.get_atomic_constant(packed_annot, dict) + annot = abstract.get_atomic_constant(packed_annot, dict) # pytype: disable=wrong-arg-types return annot - def byte_MAKE_FUNCTION(self, opcode): + def byte_MAKE_FUNCTION(self, opcode) -> None: # Aliases for readability - pop_const = lambda t: abstract.get_atomic_constant(self._stack.pop(), t) + pop_const = lambda t: abstract.get_atomic_constant(self._stack.pop(), t) # pytype: disable=wrong-arg-types arg = opcode.arg # Get name and code object if self._code.python_version >= (3, 11): @@ -505,11 +507,11 @@ def byte_MAKE_FUNCTION(self, opcode): self._functions.append(func) self._stack.push(func.to_variable()) - def byte_PUSH_NULL(self, opcode): + def byte_PUSH_NULL(self, opcode) -> None: del opcode # unused self._stack.push(self._ctx.consts.singles['NULL'].to_variable()) - def byte_LOAD_NAME(self, opcode): + def byte_LOAD_NAME(self, opcode) -> None: name = opcode.argval try: var = self.load_local(name) @@ -517,19 +519,19 @@ def byte_LOAD_NAME(self, opcode): var = self.load_global(name) self._stack.push(var) - def byte_LOAD_FAST(self, opcode): + def byte_LOAD_FAST(self, opcode) -> None: name = opcode.argval self._stack.push(self.load_local(name)) - def byte_LOAD_DEREF(self, opcode): + def byte_LOAD_DEREF(self, opcode) -> None: name = opcode.argval self._stack.push(self.load_deref(name)) - def byte_LOAD_CLOSURE(self, opcode): + def byte_LOAD_CLOSURE(self, opcode) -> None: name = opcode.argval self._stack.push(self.load_deref(name)) - def byte_LOAD_GLOBAL(self, opcode): + def byte_LOAD_GLOBAL(self, opcode) -> None: if self._code.python_version >= (3, 11) and opcode.arg & 1: # Compiler-generated marker that will be consumed in byte_CALL # We are loading a global and calling it as a function. 
@@ -537,7 +539,7 @@ def byte_LOAD_GLOBAL(self, opcode): name = opcode.argval self._stack.push(self.load_global(name)) - def byte_LOAD_ATTR(self, opcode): + def byte_LOAD_ATTR(self, opcode) -> None: attr_name = opcode.argval target_var = self._stack.pop() if self._code.python_version >= (3, 12) and opcode.arg & 1: @@ -547,14 +549,14 @@ def byte_LOAD_ATTR(self, opcode): else: self._stack.push(self.load_attr(target_var, attr_name)) - def byte_LOAD_METHOD(self, opcode): + def byte_LOAD_METHOD(self, opcode) -> None: method_name = opcode.argval instance_var = self._stack.pop() (var1, var2) = self._load_method(instance_var, method_name) self._stack.push(var1) self._stack.push(var2) - def byte_IMPORT_NAME(self, opcode): + def byte_IMPORT_NAME(self, opcode) -> None: full_name = opcode.argval unused_level_var, fromlist = self._stack.popn(2) # The IMPORT_NAME for an "import a.b.c" will push the module "a". @@ -576,7 +578,7 @@ def byte_IMPORT_NAME(self, opcode): self._ctx.errorlog.import_error(self.stack, full_name) self._stack.push(module.to_variable()) - def byte_IMPORT_FROM(self, opcode): + def byte_IMPORT_FROM(self, opcode) -> None: attr_name = opcode.argval module = self._stack.top().get_atomic_value() attr = module.get_attribute(attr_name) @@ -589,11 +591,11 @@ def byte_IMPORT_FROM(self, opcode): # --------------------------------------------------------------- # Function and method calls - def byte_KW_NAMES(self, opcode): + def byte_KW_NAMES(self, opcode) -> None: # Stores a list of kw names to be retrieved by CALL self._call_helper.set_kw_names(opcode.argval) - def byte_CALL(self, opcode): + def byte_CALL(self, opcode) -> None: sentinel, *rest = self._stack.popn(opcode.arg + 2) if not sentinel.has_atomic_value(self._ctx.consts.singles['NULL']): raise NotImplementedError('CALL not fully implemented') @@ -601,25 +603,25 @@ def byte_CALL(self, opcode): callargs = self._call_helper.make_function_args(args) self._call_function(func, callargs) - def byte_CALL_FUNCTION(self, 
opcode): + def byte_CALL_FUNCTION(self, opcode) -> None: args = self._stack.popn(opcode.arg) func = self._stack.pop() callargs = self._call_helper.make_function_args(args) self._call_function(func, callargs) - def byte_CALL_FUNCTION_KW(self, opcode): + def byte_CALL_FUNCTION_KW(self, opcode) -> None: kwnames_var = self._stack.pop() args = self._stack.popn(opcode.arg) func = self._stack.pop() kwnames = [ - abstract.get_atomic_constant(key, str) - for key in abstract.get_atomic_constant(kwnames_var, tuple) + abstract.get_atomic_constant(key, str) # pytype: disable=wrong-arg-types + for key in abstract.get_atomic_constant(kwnames_var, tuple) # pytype: disable=wrong-arg-types ] self._call_helper.set_kw_names(kwnames) callargs = self._call_helper.make_function_args(args) self._call_function(func, callargs) - def byte_CALL_FUNCTION_EX(self, opcode): + def byte_CALL_FUNCTION_EX(self, opcode) -> None: if opcode.arg & _Flags.CALL_FUNCTION_EX_HAS_KWARGS: starstarargs = self._stack.pop() else: @@ -633,7 +635,7 @@ def byte_CALL_FUNCTION_EX(self, opcode): self._stack.pop_and_discard() self._call_function(func, callargs) - def byte_CALL_METHOD(self, opcode): + def byte_CALL_METHOD(self, opcode) -> None: args = self._stack.popn(opcode.arg) func = self._stack.pop() # pop the NULL off the stack (see LOAD_METHOD) @@ -644,79 +646,79 @@ def byte_CALL_METHOD(self, opcode): # Pytype tracks variables in enclosing scopes by name rather than emulating # the runtime's approach with cells and freevars, so we can ignore the opcodes # that deal with the latter. 
- def byte_MAKE_CELL(self, opcode): + def byte_MAKE_CELL(self, opcode) -> None: del opcode # unused - def byte_COPY_FREE_VARS(self, opcode): + def byte_COPY_FREE_VARS(self, opcode) -> None: del opcode # unused - def byte_LOAD_BUILD_CLASS(self, opcode): + def byte_LOAD_BUILD_CLASS(self, opcode) -> None: self._stack.push(self._ctx.consts.singles['__build_class__'].to_variable()) # --------------------------------------------------------------- # Operators - def unary_operator(self, name): + def unary_operator(self, name) -> None: x = self._stack.pop() f = self.load_attr(x, name) self._call_function(f, abstract.Args()) - def binary_operator(self, name): + def binary_operator(self, name) -> None: (x, y) = self._stack.popn(2) ret = operators.call_binary(self._ctx, name, x, y) self._stack.push(ret) - def inplace_operator(self, name): + def inplace_operator(self, name) -> None: (x, y) = self._stack.popn(2) ret = operators.call_inplace(self._ctx, self, name, x, y) self._stack.push(ret) - def byte_UNARY_NEGATIVE(self, opcode): + def byte_UNARY_NEGATIVE(self, opcode) -> None: self.unary_operator('__neg__') - def byte_UNARY_POSITIVE(self, opcode): + def byte_UNARY_POSITIVE(self, opcode) -> None: self.unary_operator('__pos__') - def byte_UNARY_INVERT(self, opcode): + def byte_UNARY_INVERT(self, opcode) -> None: self.unary_operator('__invert__') - def byte_BINARY_MATRIX_MULTIPLY(self, opcode): + def byte_BINARY_MATRIX_MULTIPLY(self, opcode) -> None: self.binary_operator('__matmul__') - def byte_BINARY_ADD(self, opcode): + def byte_BINARY_ADD(self, opcode) -> None: self.binary_operator('__add__') - def byte_BINARY_SUBTRACT(self, opcode): + def byte_BINARY_SUBTRACT(self, opcode) -> None: self.binary_operator('__sub__') - def byte_BINARY_MULTIPLY(self, opcode): + def byte_BINARY_MULTIPLY(self, opcode) -> None: self.binary_operator('__mul__') - def byte_BINARY_MODULO(self, opcode): + def byte_BINARY_MODULO(self, opcode) -> None: self.binary_operator('__mod__') - def 
byte_BINARY_LSHIFT(self, opcode): + def byte_BINARY_LSHIFT(self, opcode) -> None: self.binary_operator('__lshift__') - def byte_BINARY_RSHIFT(self, opcode): + def byte_BINARY_RSHIFT(self, opcode) -> None: self.binary_operator('__rshift__') - def byte_BINARY_AND(self, opcode): + def byte_BINARY_AND(self, opcode) -> None: self.binary_operator('__and__') - def byte_BINARY_XOR(self, opcode): + def byte_BINARY_XOR(self, opcode) -> None: self.binary_operator('__xor__') - def byte_BINARY_OR(self, opcode): + def byte_BINARY_OR(self, opcode) -> None: self.binary_operator('__or__') - def byte_BINARY_FLOOR_DIVIDE(self, opcode): + def byte_BINARY_FLOOR_DIVIDE(self, opcode) -> None: self.binary_operator('__floordiv__') - def byte_BINARY_TRUE_DIVIDE(self, opcode): + def byte_BINARY_TRUE_DIVIDE(self, opcode) -> None: self.binary_operator('__truediv__') - def byte_BINARY_POWER(self, opcode): + def byte_BINARY_POWER(self, opcode) -> None: self.binary_operator('__pow__') def byte_BINARY_SUBSCR(self, opcode): @@ -730,46 +732,46 @@ def byte_BINARY_SUBSCR(self, opcode): ret = obj.set_type_parameters(subscr_var) self._stack.push(ret.to_variable()) - def byte_INPLACE_MATRIX_MULTIPLY(self, opcode): + def byte_INPLACE_MATRIX_MULTIPLY(self, opcode) -> None: self.inplace_operator('__imatmul__') - def byte_INPLACE_ADD(self, opcode): + def byte_INPLACE_ADD(self, opcode) -> None: self.inplace_operator('__iadd__') - def byte_INPLACE_SUBTRACT(self, opcode): + def byte_INPLACE_SUBTRACT(self, opcode) -> None: self.inplace_operator('__isub__') - def byte_INPLACE_MULTIPLY(self, opcode): + def byte_INPLACE_MULTIPLY(self, opcode) -> None: self.inplace_operator('__imul__') - def byte_INPLACE_MODULO(self, opcode): + def byte_INPLACE_MODULO(self, opcode) -> None: self.inplace_operator('__imod__') - def byte_INPLACE_POWER(self, opcode): + def byte_INPLACE_POWER(self, opcode) -> None: self.inplace_operator('__ipow__') - def byte_INPLACE_LSHIFT(self, opcode): + def byte_INPLACE_LSHIFT(self, opcode) -> None: 
self.inplace_operator('__ilshift__') - def byte_INPLACE_RSHIFT(self, opcode): + def byte_INPLACE_RSHIFT(self, opcode) -> None: self.inplace_operator('__irshift__') - def byte_INPLACE_AND(self, opcode): + def byte_INPLACE_AND(self, opcode) -> None: self.inplace_operator('__iand__') - def byte_INPLACE_XOR(self, opcode): + def byte_INPLACE_XOR(self, opcode) -> None: self.inplace_operator('__ixor__') - def byte_INPLACE_OR(self, opcode): + def byte_INPLACE_OR(self, opcode) -> None: self.inplace_operator('__ior__') - def byte_INPLACE_FLOOR_DIVIDE(self, opcode): + def byte_INPLACE_FLOOR_DIVIDE(self, opcode) -> None: self.inplace_operator('__ifloordiv__') - def byte_INPLACE_TRUE_DIVIDE(self, opcode): + def byte_INPLACE_TRUE_DIVIDE(self, opcode) -> None: self.inplace_operator('__itruediv__') - def byte_BINARY_OP(self, opcode): + def byte_BINARY_OP(self, opcode) -> None: """Implementation of BINARY_OP opcode.""" # Python 3.11 unified a lot of BINARY_* and INPLACE_* opcodes into a single # BINARY_OP. 
The underlying operations remain unchanged, so we can just @@ -820,36 +822,36 @@ def _build_collection_from_stack( constant = factory(self._ctx, typ(elements)) self._stack.push(constant.to_variable()) - def byte_BUILD_TUPLE(self, opcode): + def byte_BUILD_TUPLE(self, opcode) -> None: self._build_collection_from_stack(opcode, tuple, factory=abstract.Tuple) - def byte_BUILD_LIST(self, opcode): + def byte_BUILD_LIST(self, opcode) -> None: self._build_collection_from_stack(opcode, list, factory=abstract.List) - def byte_BUILD_SET(self, opcode): + def byte_BUILD_SET(self, opcode) -> None: self._build_collection_from_stack(opcode, set, factory=abstract.Set) - def byte_BUILD_MAP(self, opcode): + def byte_BUILD_MAP(self, opcode) -> None: n_elts = opcode.arg args = self._stack.popn(2 * n_elts) ret = {args[2 * i]: args[2 * i + 1] for i in range(n_elts)} ret = abstract.Dict(self._ctx, ret) self._stack.push(ret.to_variable()) - def byte_BUILD_CONST_KEY_MAP(self, opcode): + def byte_BUILD_CONST_KEY_MAP(self, opcode) -> None: n_elts = opcode.arg keys = self._stack.pop() # Note that `keys` is a tuple of raw python values; we do not convert them # to abstract objects because they are used internally to construct function # call args. - keys = abstract.get_atomic_constant(keys, tuple) + keys = abstract.get_atomic_constant(keys, tuple) # pytype: disable=wrong-arg-types assert len(keys) == n_elts vals = self._stack.popn(n_elts) ret = dict(zip(keys, vals)) ret = abstract.Dict(self._ctx, ret) self._stack.push(ret.to_variable()) - def byte_LIST_APPEND(self, opcode): + def byte_LIST_APPEND(self, opcode) -> None: # Used by the compiler e.g. for [x for x in ...] count = opcode.arg val = self._stack.pop() @@ -860,7 +862,7 @@ def byte_LIST_APPEND(self, opcode): target = target_var.get_atomic_value() self._replace_atomic_stack_value(count, target.append(val)) - def byte_SET_ADD(self, opcode): + def byte_SET_ADD(self, opcode) -> None: # Used by the compiler e.g. 
for {x for x in ...} count = opcode.arg val = self._stack.pop() @@ -868,7 +870,7 @@ def byte_SET_ADD(self, opcode): target = target_var.get_atomic_value() self._replace_atomic_stack_value(count, target.add(val)) - def byte_MAP_ADD(self, opcode): + def byte_MAP_ADD(self, opcode) -> None: # Used by the compiler e.g. for {x, y for x, y in ...} count = opcode.arg # The value is at the top of the stack, followed by the key. @@ -892,7 +894,7 @@ def _unpack_list_extension(self, var: _Var) -> abstract.List: self._ctx, [abstract.Splat(self._ctx, val).to_variable()] ) - def byte_LIST_EXTEND(self, opcode): + def byte_LIST_EXTEND(self, opcode) -> None: count = opcode.arg update_var = self._stack.pop() update = self._unpack_list_extension(update_var) @@ -919,11 +921,11 @@ def _unpack_dict_update(self, var: _Var) -> abstract.Dict | None: else: raise ValueError('Unexpected dict update:', val) - def byte_DICT_MERGE(self, opcode): + def byte_DICT_MERGE(self, opcode) -> None: # DICT_MERGE is like DICT_UPDATE but raises an exception for duplicate keys. 
self.byte_DICT_UPDATE(opcode) - def byte_DICT_UPDATE(self, opcode): + def byte_DICT_UPDATE(self, opcode) -> None: count = opcode.arg update_var = self._stack.pop() update = self._unpack_dict_update(update_var) @@ -939,13 +941,13 @@ def byte_DICT_UPDATE(self, opcode): self._replace_atomic_stack_value(count, ret) def _list_to_tuple(self, var: _Var) -> _Var: - target = abstract.get_atomic_constant(var, list) + target = abstract.get_atomic_constant(var, list) # pytype: disable=wrong-arg-types return abstract.Tuple(self._ctx, tuple(target)).to_variable() - def byte_LIST_TO_TUPLE(self, opcode): + def byte_LIST_TO_TUPLE(self, opcode) -> None: self._stack.push(self._list_to_tuple(self._stack.pop())) - def byte_FORMAT_VALUE(self, opcode): + def byte_FORMAT_VALUE(self, opcode) -> None: if opcode.arg & pyc_marshal.Flags.FVS_MASK: self._stack.pop_and_discard() # FORMAT_VALUE pops, formats and pushes back a string, so we just need to @@ -954,7 +956,7 @@ def byte_FORMAT_VALUE(self, opcode): ret = self._ctx.types[str].instantiate().to_variable() self._stack.push(ret) - def byte_BUILD_STRING(self, opcode): + def byte_BUILD_STRING(self, opcode) -> None: # Pop n arguments off the stack and build a string out of them self._stack.popn(opcode.arg) ret = self._ctx.types[str].instantiate().to_variable() @@ -963,51 +965,51 @@ def byte_BUILD_STRING(self, opcode): # --------------------------------------------------------------- # Branches and jumps - def byte_POP_JUMP_FORWARD_IF_FALSE(self, opcode): + def byte_POP_JUMP_FORWARD_IF_FALSE(self, opcode) -> None: self._pop_jump_if_false(opcode) - def byte_POP_JUMP_IF_FALSE(self, opcode): + def byte_POP_JUMP_IF_FALSE(self, opcode) -> None: self._pop_jump_if_false(opcode) - def byte_JUMP_FORWARD(self, opcode): + def byte_JUMP_FORWARD(self, opcode) -> None: self._merge_state_into(self._current_state, opcode.argval) # --------------------------------------------------------------- # Stack manipulation - def byte_POP_TOP(self, opcode): + def 
byte_POP_TOP(self, opcode) -> None: del opcode # unused self._stack.pop_and_discard() - def byte_DUP_TOP(self, opcode): + def byte_DUP_TOP(self, opcode) -> None: del opcode # unused self._stack.push(self._stack.top()) - def byte_DUP_TOP_TWO(self, opcode): + def byte_DUP_TOP_TWO(self, opcode) -> None: del opcode # unused a, b = self._stack.popn(2) for v in (a, b, a, b): self._stack.push(v) - def byte_ROT_TWO(self, opcode): + def byte_ROT_TWO(self, opcode) -> None: del opcode # unused self._stack.rotn(2) - def byte_ROT_THREE(self, opcode): + def byte_ROT_THREE(self, opcode) -> None: del opcode # unused self._stack.rotn(3) - def byte_ROT_FOUR(self, opcode): + def byte_ROT_FOUR(self, opcode) -> None: del opcode # unused self._stack.rotn(4) - def byte_ROT_N(self, opcode): + def byte_ROT_N(self, opcode) -> None: self._stack.rotn(opcode.arg) # --------------------------------------------------------------- # Intrinsic function calls - def _call_intrinsic(self, opcode): + def _call_intrinsic(self, opcode) -> None: try: intrinsic_impl = getattr(self, f'byte_intrinsic_{opcode.argval}') except AttributeError as e: @@ -1016,11 +1018,11 @@ def _call_intrinsic(self, opcode): ) from e intrinsic_impl() - def byte_CALL_INTRINSIC_1(self, opcode): + def byte_CALL_INTRINSIC_1(self, opcode) -> None: self._call_intrinsic(opcode) - def byte_CALL_INTRINSIC_2(self, opcode): + def byte_CALL_INTRINSIC_2(self, opcode) -> None: self._call_intrinsic(opcode) - def byte_intrinsic_INTRINSIC_LIST_TO_TUPLE(self): + def byte_intrinsic_INTRINSIC_LIST_TO_TUPLE(self) -> None: self._stack.push(self._list_to_tuple(self._stack.pop())) diff --git a/pytype/rewrite/operators.py b/pytype/rewrite/operators.py index c1df4bee9..f8b9d6e09 100644 --- a/pytype/rewrite/operators.py +++ b/pytype/rewrite/operators.py @@ -9,7 +9,7 @@ _Var = variables.Variable[abstract.BaseValue] _Binding = variables.Binding[abstract.BaseValue] -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) def 
call_binary( diff --git a/pytype/rewrite/output.py b/pytype/rewrite/output.py index 4033a4002..5fc9eed79 100644 --- a/pytype/rewrite/output.py +++ b/pytype/rewrite/output.py @@ -5,7 +5,7 @@ from pytype.pytd import pytd_utils from pytype.rewrite.abstract import abstract -_IGNORED_CLASS_ATTRIBUTES = frozenset([ +_IGNORED_CLASS_ATTRIBUTES: frozenset[str] = frozenset([ '__module__', '__qualname__', ]) @@ -48,7 +48,7 @@ def _class_to_pytd_def(self, val: abstract.SimpleClass) -> pytd.Class: if member_name in _IGNORED_CLASS_ATTRIBUTES: continue if isinstance(member_val, abstract.SimpleFunction): - member_val = member_val.bind_to(instance) + member_val = member_val.bind_to(instance) # pytype: disable=attribute-error try: member_type = self.to_pytd_def(member_val) except NotImplementedError: @@ -193,10 +193,10 @@ def to_pytd_type(self, val: abstract.BaseValue) -> pytd.Type: elif isinstance(val, abstract.BaseInstance): return pytd.NamedType(val.cls.name) elif isinstance(val, (abstract.BaseFunction, abstract.BoundFunction)): - if len(val.signatures) > 1: + if len(val.signatures) > 1: # pytype: disable=attribute-error fixed_length_posargs_only = False else: - sig = val.signatures[0] + sig = val.signatures[0] # pytype: disable=attribute-error fixed_length_posargs_only = ( not sig.defaults and not sig.varargs_name @@ -212,7 +212,7 @@ def to_pytd_type(self, val: abstract.BaseValue) -> pytd.Type: ) else: ret = abstract.join_values( - self._ctx, [frame.get_return_value() for frame in val.analyze()] + self._ctx, [frame.get_return_value() for frame in val.analyze()] # pytype: disable=attribute-error ) return pytd.GenericType( base_type=pytd.NamedType('typing.Callable'), diff --git a/pytype/rewrite/overlays/overlays.py b/pytype/rewrite/overlays/overlays.py index ec0171965..f6f1796cb 100644 --- a/pytype/rewrite/overlays/overlays.py +++ b/pytype/rewrite/overlays/overlays.py @@ -35,7 +35,7 @@ def register(transformer: _ClsTransformFuncT) -> _ClsTransformFuncT: return register -def 
initialize(): +def initialize() -> None: # Imports overlay implementations so that ther @register_* decorators execute # and populate the overlay registry. # pylint: disable=g-import-not-at-top,unused-import diff --git a/pytype/rewrite/stack.py b/pytype/rewrite/stack.py index 6e57adea2..6287ba471 100644 --- a/pytype/rewrite/stack.py +++ b/pytype/rewrite/stack.py @@ -12,7 +12,7 @@ class DataStack: """Data stack.""" - def __init__(self): + def __init__(self) -> None: self._stack: list[_Var] = [] def push(self, var: _Var) -> None: @@ -64,11 +64,11 @@ def _stack_size_error(self, msg): msg = f'Trying to {msg} in a stack of size {len(self._stack)}' raise IndexError(msg) - def __bool__(self): + def __bool__(self) -> bool: return bool(self._stack) - def __len__(self): + def __len__(self) -> int: return len(self._stack) - def __repr__(self): + def __repr__(self) -> str: return f'DataStack{self._stack}' diff --git a/pytype/rewrite/tests/test_utils.py b/pytype/rewrite/tests/test_utils.py index b420442e9..76dfe11e9 100644 --- a/pytype/rewrite/tests/test_utils.py +++ b/pytype/rewrite/tests/test_utils.py @@ -1,6 +1,6 @@ """Test utilities.""" -from collections.abc import Sequence +from collections.abc import Sequence, Callable import re import sys import textwrap @@ -17,7 +17,7 @@ class ContextfulTestBase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.ctx = context.Context(src='') @@ -52,23 +52,23 @@ def __init__(self, ops: Sequence[list[opcodes.Opcode]], consts=()): # pylint: disable=invalid-name # Use camel-case to match the unittest.skip* methods. 
-def skipIfPy(*versions, reason): +def skipIfPy(*versions, reason) -> Callable[[Callable], Callable]: return unittest.skipIf(sys.version_info[:2] in versions, reason) -def skipUnlessPy(*versions, reason): +def skipUnlessPy(*versions, reason) -> Callable[[Callable], Callable]: return unittest.skipUnless(sys.version_info[:2] in versions, reason) -def skipBeforePy(version, reason): +def skipBeforePy(version, reason) -> Callable[[Callable], Callable]: return unittest.skipIf(sys.version_info[:2] < version, reason) -def skipFromPy(version, reason): +def skipFromPy(version, reason) -> Callable[[Callable], Callable]: return unittest.skipUnless(sys.version_info[:2] < version, reason) -def skipOnWin32(reason): +def skipOnWin32(reason) -> Callable[[Callable], Callable]: return unittest.skipIf(sys.platform == 'win32', reason) diff --git a/pytype/rewrite/vm.py b/pytype/rewrite/vm.py index 31ccc0b1d..65805d8d9 100644 --- a/pytype/rewrite/vm.py +++ b/pytype/rewrite/vm.py @@ -2,7 +2,7 @@ from collections.abc import Sequence import logging - +from typing import TypeVar from pytype import config from pytype.blocks import blocks from pytype.pyc import pyc @@ -12,7 +12,9 @@ from pytype.rewrite import frame as frame_lib from pytype.rewrite.abstract import abstract -log = logging.getLogger(__name__) +_TVirtualMachine = TypeVar('_TVirtualMachine', bound='VirtualMachine') + +log: logging.Logger = logging.getLogger(__name__) class VirtualMachine: diff --git a/pytype/state.py b/pytype/state.py index bd2b9957c..71079b84f 100644 --- a/pytype/state.py +++ b/pytype/state.py @@ -14,18 +14,18 @@ from pytype.blocks import blocks from pytype.typegraph import cfg -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) # A special constant, returned by split_conditions() to signal that the # condition cannot be satisfied with any known bindings. -UNSATISFIABLE = object() +UNSATISFIABLE: Any = object() # Represents the `not None` condition for restrict_condition(). 
-NOT_NONE = object() +NOT_NONE: Any = object() FrameType = Union["SimpleFrame", "Frame"] # This should be context.Context, which can't be imported due to a circular dep. -_ContextType = Any +_ContextType: Any = Any class FrameState(utils.ContextWeakrefMixin): @@ -33,7 +33,9 @@ class FrameState(utils.ContextWeakrefMixin): __slots__ = ["block_stack", "data_stack", "node", "exception", "why"] - def __init__(self, data_stack, block_stack, node, ctx, exception, why): + def __init__( + self, data_stack, block_stack, node, ctx, exception, why + ): super().__init__(ctx) self.data_stack = data_stack self.block_stack = block_stack @@ -85,7 +87,7 @@ def topn(self, n): else: return () - def pop(self): + def pop(self) -> tuple[Any, Any]: """Pop a value from the value stack.""" if not self.data_stack: raise IndexError("Trying to pop from an empty stack") @@ -96,7 +98,7 @@ def pop_and_discard(self): """Pop a value from the value stack and discard it.""" return self.set_stack(self.data_stack[:-1]) - def popn(self, n): + def popn(self, n) -> tuple[Any, Any]: """Return n values, ordered oldest-to-newest.""" if not n: # Not an error: E.g. function calls with no parameters pop zero items @@ -245,7 +247,7 @@ class SimpleFrame: error logging. 
""" - def __init__(self, opcode=None, node=None, f_globals=None): + def __init__(self, opcode=None, node=None, f_globals=None) -> None: self.f_code = None # for recursion detection self.f_builtins = None self.f_globals = f_globals @@ -413,14 +415,14 @@ def __init__( str, list[abstract.InterpreterFunction] ] = collections.defaultdict(list) - def __repr__(self): # pragma: no cover + def __repr__(self) -> str: # pragma: no cover return "" % ( id(self), self.f_code.filename, self.f_lineno, ) - def copy_free_vars(self, n): + def copy_free_vars(self, n) -> None: offset = len(self.cells) - len(self.f_code.freevars) for i in range(n): self.cells[i + offset] = self.closure[i] @@ -447,7 +449,7 @@ class Condition: binding: A Binding for the condition's constraints. """ - def __init__(self, node, dnf): + def __init__(self, node, dnf) -> None: # The condition is represented by a dummy variable with a single binding # to None. The origins for this binding are the dnf clauses. self._var = node.program.NewVariable() @@ -461,7 +463,7 @@ def binding(self): return self._binding -_restrict_counter = metrics.MapCounter("state_restrict") +_restrict_counter: metrics.MapCounter = metrics.MapCounter("state_restrict") def _match_condition(value, condition): diff --git a/pytype/tests/test_base.py b/pytype/tests/test_base.py index 48a1a9a5c..1e618901e 100644 --- a/pytype/tests/test_base.py +++ b/pytype/tests/test_base.py @@ -4,6 +4,7 @@ import logging import sys import textwrap +from typing import Any # from absl import flags from pytype import analyze @@ -21,10 +22,12 @@ from pytype.pytd import visitors from pytype.rewrite import analyze as rewrite_analyze from pytype.tests import test_utils +from pytype.tests.test_utils import ErrorMatcher import unittest -log = logging.getLogger(__name__) + +log: logging.Logger = logging.getLogger(__name__) # Make this false if you need to run the debugger inside a test. 
@@ -52,7 +55,7 @@ def _MatchLoaderConfig(options, loader): return options == loader.options -def _Format(code): +def _Format(code) -> str: # Removes the leading newline introduced by writing, e.g., # self.Check(""" # code @@ -65,7 +68,7 @@ def _Format(code): class UnitTest(unittest.TestCase): """Base class for tests that specify a target Python version.""" - python_version = sys.version_info[:2] + python_version: tuple[int, int] = sys.version_info[:2] class BaseTest(unittest.TestCase): @@ -75,13 +78,13 @@ class BaseTest(unittest.TestCase): python_version: tuple[int, int] = sys.version_info[:2] @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: super().setUpClass() # We use class-wide loader to avoid creating a new loader for every test # method if not required. cls._loader = None - def setUp(self): + def setUp(self) -> None: super().setUp() self.options = config.Options.create( python_version=self.python_version, @@ -109,13 +112,13 @@ def loader(self): def analyze_lib(self): return rewrite_analyze if self.options.use_rewrite else analyze - def ConfigureOptions(self, **kwargs): + def ConfigureOptions(self, **kwargs) -> None: assert ( "python_version" not in kwargs ), "Individual tests cannot set the python_version of the config options." 
self.options.tweak(**kwargs) - def _GetPythonpathArgs(self, pythonpath, imports_map): + def _GetPythonpathArgs(self, pythonpath, imports_map) -> dict[str, Any]: """Gets values for --pythonpath and --imports_map.""" if pythonpath: pythonpath_arg = pythonpath @@ -139,7 +142,7 @@ def Check( quick=False, imports_map=None, **kwargs, - ): + ) -> None: """Run an inference smoke test for the given code.""" self.ConfigureOptions( skip_repeat_calls=skip_repeat_calls, @@ -160,12 +163,12 @@ def Check( errorlog.print_to_stderr() self.fail(f"Checker found {len(errorlog)} errors:\n{errorlog}") - def assertNoCrash(self, method, code, **kwargs): + def assertNoCrash(self, method, code, **kwargs) -> None: method(code, report_errors=False, **kwargs) def _SetUpErrorHandling( self, code, pythonpath, analyze_annotated, quick, imports_map - ): + ) -> dict[str, Any]: code = _Format(code) self.ConfigureOptions( analyze_annotated=analyze_annotated, @@ -183,7 +186,7 @@ def InferWithErrors( quick=False, imports_map=None, **kwargs, - ): + ) -> tuple[Any, ErrorMatcher]: """Runs inference on code expected to have type errors.""" kwargs.update( self._SetUpErrorHandling( @@ -217,7 +220,7 @@ def CheckWithErrors( quick=False, imports_map=None, **kwargs, - ): + ) -> ErrorMatcher: """Check and match errors.""" kwargs.update( self._SetUpErrorHandling( @@ -252,24 +255,24 @@ def InferFromFile(self, filename, pythonpath): unit.Visit(visitors.VerifyVisitor()) return pytd_utils.CanonicalOrdering(unit) - def assertErrorRegexes(self, matcher, expected_errors): + def assertErrorRegexes(self, matcher, expected_errors) -> None: matcher.assert_error_regexes(expected_errors) - def assertErrorSequences(self, matcher, expected_errors): + def assertErrorSequences(self, matcher, expected_errors) -> None: matcher.assert_error_sequences(expected_errors) - def assertDiagnosticRegexes(self, matcher, expected_errors): + def assertDiagnosticRegexes(self, matcher, expected_errors) -> None: 
matcher.assert_diagnostic_regexes(expected_errors) - def assertDiagnosticMessages(self, matcher, expected_errors): + def assertDiagnosticMessages(self, matcher, expected_errors) -> None: matcher.assert_diagnostic_messages(expected_errors) - def _PickleAst(self, ast, module_name): + def _PickleAst(self, ast, module_name) -> bytes: assert module_name ast = serialize_ast.PrepareForExport(module_name, ast, self.loader) return pickle_utils.Serialize(ast) - def _PickleSource(self, src, module_name): + def _PickleSource(self, src, module_name) -> bytes: ast = serialize_ast.SourceToExportableAst( module_name, textwrap.dedent(src), self.loader ) @@ -318,7 +321,7 @@ def _InferAndVerify( imports_map=None, quick=False, **kwargs, - ): + ) -> tuple[Any, Any]: """Infer types for the source code treating it as a module. Used by Infer(). @@ -360,7 +363,7 @@ def _InferAndVerify( self.fail(f"Inferencer found {len(errorlog)} errors:\n{errorlog}") return unit, ret.ast_deps - def assertTypesMatchPytd(self, ty, pytd_src): + def assertTypesMatchPytd(self, ty, pytd_src) -> None: """Parses pytd_src and compares with ty.""" pytd_tree = parser.parse_string( textwrap.dedent(pytd_src), @@ -435,12 +438,12 @@ def DepTree(self, deps): ) -def _PrintErrorDebug(descr, value): +def _PrintErrorDebug(descr, value) -> None: log.error("=============== %s ===========", descr) _LogLines(log.error, value) log.error("=========== end %s ===========", descr) -def _LogLines(log_cmd, lines): +def _LogLines(log_cmd, lines) -> None: for l in lines.split("\n"): log_cmd("%s", l) diff --git a/pytype/tests/test_utils.py b/pytype/tests/test_utils.py index 664350f8b..9e9131ae9 100644 --- a/pytype/tests/test_utils.py +++ b/pytype/tests/test_utils.py @@ -1,6 +1,7 @@ """Utility class and function for tests.""" import collections +from collections.abc import Callable import copy import dataclasses import io @@ -11,6 +12,7 @@ import sys import textwrap import tokenize +from typing import Any, TypeVar import pycnite.mapping 
import pycnite.types @@ -25,24 +27,27 @@ from pytype.platform_utils import path_utils from pytype.platform_utils import tempfile as compatible_tempfile from pytype.pytd import slots +from pytype.state import SimpleFrame import unittest +_TTempdir = TypeVar("_TTempdir", bound="Tempdir") + class Tempdir: """Context handler for creating temporary directories.""" - def __enter__(self): + def __enter__(self: _TTempdir) -> _TTempdir: self.path = compatible_tempfile.mkdtemp() return self - def create_directory(self, filename): + def create_directory(self, filename) -> str: """Create a subdirectory in the temporary directory.""" path = path_utils.join(self.path, filename) makedirs(path) return path - def create_file(self, filename, indented_data=None): + def create_file(self, filename, indented_data=None) -> str: """Create a file in the temporary directory. Dedents the data if needed.""" filedir, filename = path_utils.split(filename) if filedir: @@ -60,14 +65,14 @@ def create_file(self, filename, indented_data=None): fi.write(data) return path - def delete_file(self, filename): + def delete_file(self, filename) -> None: os.unlink(path_utils.join(self.path, filename)) - def __exit__(self, error_type, value, tb): + def __exit__(self, error_type: None, value: None, tb: None) -> bool: shutil.rmtree(path=self.path) return False # reraise any exceptions - def __getitem__(self, filename): + def __getitem__(self, filename) -> str: """Get the full path for an entry in this directory.""" return path_utils.join(self.path, filename) @@ -81,7 +86,7 @@ class FakeCode: class FakeOpcode: """Util class for generating fake Opcode for testing.""" - def __init__(self, filename, line, endline, col, endcol, methodname): + def __init__(self, filename, line, endline, col, endcol, methodname) -> None: self.code = FakeCode(filename, methodname) self.line = line self.endline = endline @@ -89,7 +94,7 @@ def __init__(self, filename, line, endline, col, endcol, methodname): self.endcol = endcol 
self.name = "FAKE_OPCODE" - def to_stack(self): + def to_stack(self) -> list[SimpleFrame]: return [frame_state.SimpleFrame(self)] @@ -105,7 +110,7 @@ def fake_stack(length): class FakePrettyPrinter(pretty_printer_base.PrettyPrinterBase): """Fake pretty printer for constructing an error log.""" - def __init__(self): + def __init__(self) -> None: options = config.Options.create() super().__init__(make_context(options)) @@ -133,7 +138,7 @@ class OperatorsTestMixin: _HAS_DYNAMIC_ATTRIBUTES = True - def check_expr(self, expr, assignments, expected_return): + def check_expr(self, expr, assignments, expected_return) -> None: """Check the expression.""" # Note that testing "1+2" as opposed to "x=1; y=2; x+y" doesn't really test # anything because the peephole optimizer converts "1+2" to "3" and __add__ @@ -151,7 +156,7 @@ def f(): ty = self.Infer(src) self.assertTypesMatchPytd(ty, f"def f() -> {expected_return}: ...") - def check_binary(self, function_name, op): + def check_binary(self, function_name, op) -> None: """Check the binary operator.""" ty = self.Infer(f""" class Foo: @@ -173,7 +178,7 @@ def f() -> complex: ... """, ) - def check_unary(self, function_name, op, ret=None): + def check_unary(self, function_name, op, ret=None) -> None: """Check the unary operator.""" ty = self.Infer(f""" class Foo: @@ -192,7 +197,7 @@ def f() -> {ret or "complex"}: ... """, ) - def check_reverse(self, function_name, op): + def check_reverse(self, function_name, op) -> None: """Check the reverse operator.""" ty = self.Infer(f""" class Foo: @@ -225,7 +230,7 @@ def i() -> complex: ... 
""", ) - def check_inplace(self, function_name, op): + def check_inplace(self, function_name, op) -> None: """Check the inplace operator.""" ty = self.Infer(f""" class Foo: @@ -252,7 +257,7 @@ class InplaceTestMixin: _HAS_DYNAMIC_ATTRIBUTES = True - def _check_inplace(self, op, assignments, expected_return): + def _check_inplace(self, op, assignments, expected_return) -> None: """Check the inplace operator.""" assignments = "; ".join(assignments) src = f""" @@ -270,7 +275,7 @@ class TestCollectionsMixin: _HAS_DYNAMIC_ATTRIBUTES = True - def _testCollectionsObject(self, obj, good_arg, bad_arg, error): # pylint: disable=invalid-name + def _testCollectionsObject(self, obj, good_arg, bad_arg, error) -> None: # pylint: disable=invalid-name result = self.CheckWithErrors(f""" import collections def f(x: collections.{obj}): ... @@ -285,7 +290,7 @@ class MakeCodeMixin: _HAS_DYNAMIC_ATTRIBUTES = True - def make_code(self, int_array, name="testcode"): + def make_code(self, int_array, name="testcode") -> pycnite.types.CodeType38: """Utility method for creating CodeType objects.""" return pycnite.types.CodeType38( co_argcount=0, @@ -311,23 +316,23 @@ def make_code(self, int_array, name="testcode"): class RegexMatcher: """Match a regex.""" - def __init__(self, regex): + def __init__(self, regex) -> None: self.regex = regex def match(self, message): return re.search(self.regex, message, flags=re.DOTALL) - def __repr__(self): + def __repr__(self) -> str: return repr(self.regex) class SequenceMatcher: """Match a sequence of substrings in order.""" - def __init__(self, seq): + def __init__(self, seq) -> None: self.seq = seq - def match(self, message): + def match(self, message) -> bool: start = 0 for s in self.seq: i = message.find(s, start) @@ -336,7 +341,7 @@ def match(self, message): start = i + len(s) return True - def __repr__(self): + def __repr__(self) -> str: return repr(self.seq) @@ -357,12 +362,12 @@ class ErrorMatcher: See tests/test_base_test.py for usage examples. 
""" - ERROR_RE = re.compile( + ERROR_RE: re.Pattern[str] = re.compile( r"^(?P(\w+-)+\w+)(\[(?P.+)\])?" r"((?P([!=]=|[<>]=?))(?P\d+\.\d+))?$" ) - def __init__(self, src): + def __init__(self, src) -> None: # errorlog and marks are set by assert_errors_match_expected() self.errorlog = None self.marks = None @@ -376,7 +381,7 @@ def _fail(self, msg): def has_error(self): return self.errorlog and self.errorlog.has_error() - def assert_errors_match_expected(self, errorlog): + def assert_errors_match_expected(self, errorlog) -> None: """Matches expected errors against the errorlog, populating self.marks.""" def _format_error(line, code, mark=None): @@ -415,7 +420,7 @@ def _format_error(line, code, mark=None): if leftover_errors: self._fail("Errors not found:\n" + "\n".join(leftover_errors)) - def _assert_error_messages(self, matchers): + def _assert_error_messages(self, matchers) -> None: """Assert error messages.""" assert self.marks is not None for mark, error in self.marks.items(): @@ -431,7 +436,7 @@ def _assert_error_messages(self, matchers): if matchers: self._fail(f"Marks not found in code: {', '.join(matchers)}") - def assert_diagnostic_messages(self, matchers): + def assert_diagnostic_messages(self, matchers) -> None: """Assert error messages.""" assert self.marks is not None for mark, error in self.marks.items(): @@ -452,21 +457,21 @@ def assert_diagnostic_messages(self, matchers): if matchers: self._fail(f"Marks not found in code: {', '.join(matchers)}") - def assert_error_regexes(self, expected_regexes): + def assert_error_regexes(self, expected_regexes) -> None: matchers = {k: RegexMatcher(v) for k, v in expected_regexes.items()} self._assert_error_messages(matchers) - def assert_error_sequences(self, expected_sequences): + def assert_error_sequences(self, expected_sequences) -> None: matchers = {k: SequenceMatcher(v) for k, v in expected_sequences.items()} self._assert_error_messages(matchers) - def assert_diagnostic_regexes(self, 
expected_diagnostic_regexes): + def assert_diagnostic_regexes(self, expected_diagnostic_regexes) -> None: matchers = { k: RegexMatcher(v) for k, v in expected_diagnostic_regexes.items() } self.assert_diagnostic_messages(matchers) - def _parse_comment(self, comment): + def _parse_comment(self, comment) -> tuple[Any, Any] | None: comment = comment.strip() error_match = self.ERROR_RE.fullmatch(comment) if not error_match: @@ -479,7 +484,7 @@ def _parse_comment(self, comment): return None return error_match.group("code"), error_match.group("mark") - def _parse_comments(self, src): + def _parse_comments(self, src) -> collections.defaultdict[int, Any]: """Parse comments.""" src = io.StringIO(src) expected = collections.defaultdict(list) @@ -515,34 +520,34 @@ class Py310Opcodes: # pylint: disable=invalid-name # Use camel-case to match the unittest.skip* methods. -def skipIfPy(*versions, reason): +def skipIfPy(*versions, reason) -> Callable[[Callable], Callable]: return unittest.skipIf(sys.version_info[:2] in versions, reason) -def skipUnlessPy(*versions, reason): +def skipUnlessPy(*versions, reason) -> Callable[[Callable], Callable]: return unittest.skipUnless(sys.version_info[:2] in versions, reason) -def skipBeforePy(version, reason): +def skipBeforePy(version, reason) -> Callable[[Callable], Callable]: return unittest.skipIf(sys.version_info[:2] < version, reason) -def skipFromPy(version, reason): +def skipFromPy(version, reason) -> Callable[[Callable], Callable]: return unittest.skipUnless(sys.version_info[:2] < version, reason) -def skipOnWin32(reason): +def skipOnWin32(reason) -> Callable[[Callable], Callable]: return unittest.skipIf(sys.platform == "win32", reason) -def make_context(options, src=""): +def make_context(options, src="") -> context.Context: """Create a minimal context for tests.""" return context.Context( options=options, loader=load_pytd.Loader(options), src=src ) -def test_data_file(filename): +def test_data_file(filename) -> str: pytype_dir = 
path_utils.dirname(path_utils.dirname(path_utils.__file__)) code = path_utils.join( pytype_dir, file_utils.replace_separator("test_data/"), filename diff --git a/pytype/tools/analyze_project/config.py b/pytype/tools/analyze_project/config.py index 5269d6084..a14966369 100644 --- a/pytype/tools/analyze_project/config.py +++ b/pytype/tools/analyze_project/config.py @@ -7,7 +7,7 @@ import os import sys import textwrap -from typing import Any +from typing import Any, TypeVar, Union from pytype import config as pytype_config from pytype import file_utils @@ -15,6 +15,9 @@ from pytype.platform_utils import path_utils from pytype.tools import config + +_T0 = TypeVar('_T0') + _TOML = '.toml' @@ -48,7 +51,7 @@ class Item: # Generates both the default config and the sample config file. These items # don't have ArgInfo populated, as it is needed only for pytype-single args. -ITEMS = { +ITEMS: dict[str, Item] = { 'exclude': Item( '', '**/*_test.py **/test_*.py', None, 'Space-separated list of files or directories to exclude.'), @@ -76,7 +79,7 @@ class Item: } -REPORT_ERRORS_ITEMS = { +REPORT_ERRORS_ITEMS: dict[str, Item] = { 'disable': Item( None, 'pyi-error', ArgInfo('--disable', ','.join), 'Space-separated list of error names to ignore.'), @@ -86,7 +89,7 @@ class Item: # The missing fields will be filled in by generate_sample_config_or_die. -def _pytype_single_items(): +def _pytype_single_items() -> dict[str, Item]: """Args to pass through to pytype_single.""" out = {} flags = pytype_config.FEATURE_FLAGS + pytype_config.EXPERIMENTAL_FLAGS @@ -99,10 +102,10 @@ def _pytype_single_items(): return out -_PYTYPE_SINGLE_ITEMS = _pytype_single_items() +_PYTYPE_SINGLE_ITEMS: dict[str, Item] = _pytype_single_items() -def get_pytype_single_item(name): +def get_pytype_single_item(name) -> Item: # We want to avoid exposing this hard-coded list as much as possible so that # parser.pytype_single_args, which is guaranteed to match the actual args, is # used instead. 
@@ -113,19 +116,19 @@ def string_to_bool(s): return s == 'True' if s in ('True', 'False') else s -def concat_disabled_rules(s): +def concat_disabled_rules(s) -> str: return ','.join(t for t in s.split() if t) -def get_platform(p): +def get_platform(p: _T0) -> Union[str, _T0]: return p or sys.platform -def get_python_version(v): +def get_python_version(v: _T0) -> Union[str, _T0]: return v or utils.format_version(sys.version_info[:2]) -def parse_jobs(s): +def parse_jobs(s) -> int: """Parse the --jobs option.""" if s == 'auto': try: @@ -137,7 +140,7 @@ def parse_jobs(s): return int(s) -def make_converters(cwd=None): +def make_converters(cwd=None) -> dict[str, Callable[[Any], Any]]: """For items that need coaxing into their internal representations.""" return { 'disable': concat_disabled_rules, @@ -152,14 +155,14 @@ def make_converters(cwd=None): } -def _toml_format(v): +def _toml_format(v) -> str: try: return str(int(v)) except ValueError: return str(v).lower() if v in ('True', 'False') else repr(v) -def _make_path_formatter(ext): +def _make_path_formatter(ext) -> Callable[[Any], Any]: """Formatter for a string of paths.""" def format_path(p): paths = p.split() @@ -170,7 +173,7 @@ def format_path(p): return format_path -def make_formatters(ext): +def make_formatters(ext) -> dict[str, Any]: return { 'disable': _make_path_formatter(ext), 'exclude': _make_path_formatter(ext), @@ -208,7 +211,7 @@ def __str__(self): class FileConfig(argparse.Namespace): """Configuration variables from a file.""" - def read_from_file(self, filepath): + def read_from_file(self, filepath: _T0) -> _T0 | None: """Read config from the pytype section of a configuration file.""" _, ext = os.path.splitext(filepath) @@ -227,7 +230,7 @@ def read_from_file(self, filepath): return filepath -def generate_sample_config_or_die(filename, pytype_single_args): +def generate_sample_config_or_die(filename, pytype_single_args) -> None: """Write out a sample config file.""" if path_utils.exists(filename): @@ 
-271,7 +274,7 @@ def generate_sample_config_or_die(filename, pytype_single_args): sys.exit(1) -def read_config_file_or_die(filepath): +def read_config_file_or_die(filepath) -> FileConfig: """Read config from filepath or from setup.cfg.""" ret = FileConfig() diff --git a/pytype/tools/analyze_project/environment.py b/pytype/tools/analyze_project/environment.py index 2cd14aaf1..621cfcdc9 100644 --- a/pytype/tools/analyze_project/environment.py +++ b/pytype/tools/analyze_project/environment.py @@ -9,11 +9,11 @@ class PytdFileSystem(fs.ExtensionRemappingFileSystem): """File system that remaps .py file extensions to pytd.""" - def __init__(self, underlying): + def __init__(self, underlying) -> None: super().__init__(underlying, 'pytd') -def create_importlab_environment(conf, typeshed): +def create_importlab_environment(conf, typeshed) -> environment.Environment: """Create an importlab environment from the python version and path.""" python_version = utils.version_from_string(conf.python_version) path = fs.Path() diff --git a/pytype/tools/analyze_project/parse_args.py b/pytype/tools/analyze_project/parse_args.py index 3262c4287..7cd2147d1 100644 --- a/pytype/tools/analyze_project/parse_args.py +++ b/pytype/tools/analyze_project/parse_args.py @@ -25,7 +25,7 @@ def convert_string(s): class Parser: """Parser with additional functions for config file processing.""" - def __init__(self, parser, pytype_single_args): + def __init__(self, parser, pytype_single_args) -> None: """Initialize a parser. 
Args: @@ -36,7 +36,7 @@ def __init__(self, parser, pytype_single_args): self.pytype_single_args = pytype_single_args self._pytype_arg_map = pytype_config.args_map() - def create_initial_args(self, keys): + def create_initial_args(self, keys) -> argparse.Namespace: """Creates the initial set of args.""" return argparse.Namespace(**{k: None for k in keys}) @@ -47,7 +47,7 @@ def config_from_defaults(self): conf.populate_from(defaults) return conf - def clean_args(self, args, keys): + def clean_args(self, args, keys) -> None: """Clean None values out of the arg namespace. This lets us check for a config file arg based on whether the None default @@ -80,7 +80,7 @@ def parse_args(self, argv): self.postprocess(args) return args - def convert_strings(self, args: argparse.Namespace): + def convert_strings(self, args: argparse.Namespace) -> None: """Converts strings in an args namespace to values.""" for k in self.pytype_single_args: if hasattr(args, k): @@ -88,7 +88,7 @@ def convert_strings(self, args: argparse.Namespace): assert isinstance(v, str) setattr(args, k, convert_string(v)) - def postprocess(self, args: argparse.Namespace): + def postprocess(self, args: argparse.Namespace) -> None: """Postprocesses the subset of pytype_single_args that appear in args. Args: @@ -98,11 +98,11 @@ def postprocess(self, args: argparse.Namespace): opt_map = {k: self._pytype_arg_map[k].long_opt for k in names} pytype_config.Postprocessor(names, opt_map, args).process() - def error(self, message): + def error(self, message) -> None: self._parser.error(message) -def make_parser(): +def make_parser() -> Parser: """Make parser for command line args. Returns: @@ -163,7 +163,7 @@ def make_parser(): class _FlattenAction(argparse.Action): """Flattens a list of sets. 
Used by --exclude and inputs.""" - def __call__(self, parser, namespace, values, option_string=None): + def __call__(self, parser, namespace, values, option_string=None) -> None: items = getattr(namespace, self.dest, None) or set() # We want to keep items as None if values is empty, since that means the # argument was not passed on the command line. Note that an empty values @@ -175,7 +175,7 @@ def __call__(self, parser, namespace, values, option_string=None): items.update(v) -def _add_file_argument(parser, types, args, custom_kwargs=None): +def _add_file_argument(parser, types, args, custom_kwargs=None) -> None: """Add a file-configurable option to the parser. Args: diff --git a/pytype/tools/analyze_project/pytype_runner.py b/pytype/tools/analyze_project/pytype_runner.py index a2065c096..3b3043b50 100644 --- a/pytype/tools/analyze_project/pytype_runner.py +++ b/pytype/tools/analyze_project/pytype_runner.py @@ -8,13 +8,18 @@ import re import subprocess import sys +from typing import TypeVar, Union from pytype import file_utils from pytype import module_utils from pytype import utils +from pytype.module_utils import Module from pytype.platform_utils import path_utils from pytype.tools.analyze_project import config +_T0 = TypeVar('_T0') +_T1 = TypeVar('_T1') + # Generate a default pyi for builtin and system dependencies. DEFAULT_PYI = """ from typing import Any @@ -37,26 +42,31 @@ class Stage: FIRST_PASS_SUFFIX = '-1' -def _get_executable(binary, module=None): +def _get_executable( + binary: _T0, module: _T1 = None +) -> list[Union[str, _T0, _T1]]: """Get the path to the executable with the given name.""" if binary == 'pytype-single': custom_bin = path_utils.join('out', 'bin', 'pytype') if sys.argv[0] == custom_bin: # The Travis type-check step uses custom binaries in pytype/out/bin/. 
- return (([] if sys.platform != 'win32' else [sys.executable]) + [ + return ([] if sys.platform != 'win32' else [sys.executable]) + [ path_utils.join( path_utils.abspath(path_utils.dirname(custom_bin)), - 'pytype-single') - ]) + 'pytype-single', + ) + ] importable = importlib.util.find_spec(module or binary) if sys.executable is not None and importable: return [sys.executable, '-m', module or binary] else: return [binary] -PYTYPE_SINGLE = _get_executable('pytype-single', 'pytype.main') -def resolved_file_to_module(f): +PYTYPE_SINGLE: list[str] = _get_executable('pytype-single', 'pytype.main') + + +def resolved_file_to_module(f) -> Module: """Turn an importlab ResolvedFile into a pytype Module.""" full_path = f.path target = f.short_path @@ -69,7 +79,7 @@ def resolved_file_to_module(f): path=path, target=target, name=name, kind=f.__class__.__name__) -def _get_filenames(node): +def _get_filenames(node) -> tuple: if isinstance(node, str): return (node,) else: @@ -77,7 +87,7 @@ def _get_filenames(node): return tuple(sorted(node.nodes)) -def deps_from_import_graph(import_graph): +def deps_from_import_graph(import_graph) -> list[tuple[tuple, tuple]]: """Construct PytypeRunner args from an importlab.ImportGraph instance. Kept as a separate function so PytypeRunner can be tested independently of @@ -126,7 +136,7 @@ def split_files(filenames): return modules -def _is_type_stub(f): +def _is_type_stub(f) -> bool: _, ext = path_utils.splitext(f) return ext in ('.pyi', '.pytd') @@ -146,7 +156,7 @@ def _module_to_output_path(mod): return mod.name[0] + mod.name[1:].replace('.', path_utils.sep) -def escape_ninja_path(path: str): +def escape_ninja_path(path: str) -> str: """Returns the path with special characters escaped. 
Escape new line, space, colon, and dollar sign, for ninja @@ -161,7 +171,7 @@ def escape_ninja_path(path: str): return re.sub(r'(?P[\n :$])', r'$\g', path) -def get_imports_map(deps, module_to_imports_map, module_to_output): +def get_imports_map(deps, module_to_imports_map, module_to_output) -> dict: """Get a short path -> full path map for the given deps.""" imports_map = {} for m in deps: @@ -174,7 +184,7 @@ def get_imports_map(deps, module_to_imports_map, module_to_output): class PytypeRunner: """Runs pytype over an import graph.""" - def __init__(self, conf, sorted_sources): + def __init__(self, conf, sorted_sources) -> None: self.filenames = set(conf.inputs) # files to type-check # all source modules as a sequence of (module, direct_deps) self.sorted_sources = sorted_sources @@ -184,11 +194,14 @@ def __init__(self, conf, sorted_sources): self.imports_dir = path_utils.join(conf.output, 'imports') self.ninja_file = path_utils.join(conf.output, 'build.ninja') self.custom_options = [ - (k, getattr(conf, k)) for k in set(conf.__slots__) - set(config.ITEMS)] + (k, getattr(conf, k)) for k in set(conf.__slots__) - set(config.ITEMS) + ] self.keep_going = conf.keep_going self.jobs = conf.jobs - def set_custom_options(self, flags_with_values, binary_flags, report_errors): + def set_custom_options( + self, flags_with_values, binary_flags, report_errors + ) -> None: """Merge self.custom_options into flags_with_values and binary_flags.""" for dest, value in self.custom_options: if not report_errors and dest in config.REPORT_ERRORS_ITEMS: @@ -229,7 +242,7 @@ def get_pytype_command_for_ninja(self, report_errors): ['$in'] ) - def make_imports_dir(self): + def make_imports_dir(self) -> bool: try: file_utils.makedirs(self.imports_dir) except OSError: @@ -237,14 +250,14 @@ def make_imports_dir(self): return False return True - def write_default_pyi(self): + def write_default_pyi(self) -> str: """Write a default pyi file.""" output = path_utils.join(self.imports_dir, 
'default.pyi') with open(output, 'w') as f: f.write(DEFAULT_PYI) return output - def write_imports(self, module_name, imports_map, suffix): + def write_imports(self, module_name, imports_map, suffix) -> str: """Write a .imports file.""" output = path_utils.join(self.imports_dir, module_name + '.imports' + suffix) @@ -253,7 +266,7 @@ def write_imports(self, module_name, imports_map, suffix): f.write('%s %s\n' % item) return output - def get_module_action(self, module): + def get_module_action(self, module) -> str: """Get the action for the given module. Args: @@ -308,7 +321,7 @@ def yield_sorted_modules( if action != Action.GENERATE_DEFAULT: yield module, action, deps, Stage.SECOND_PASS - def write_ninja_preamble(self): + def write_ninja_preamble(self) -> None: """Write out the pytype-single commands that the build will call.""" with open(self.ninja_file, 'w') as f: for action, report_errors in ((Action.INFER, False), @@ -323,7 +336,7 @@ def write_ninja_preamble(self): action=action, command=command) ) - def write_build_statement(self, module, action, deps, imports, suffix): + def write_build_statement(self, module, action, deps, imports, suffix) -> str: """Write a build statement for the given module. Args: @@ -356,7 +369,7 @@ def write_build_statement(self, module, action, deps, imports, suffix): module=module.name)) return output - def setup_build(self): + def setup_build(self) -> set[str]: """Write out the full build.ninja file. 
Returns: @@ -395,7 +408,7 @@ def setup_build(self): module, action, deps, imports, suffix) return files - def build(self): + def build(self) -> int: """Execute the build.ninja file.""" # -k N keep going until N jobs fail (0 means infinity) # -C DIR change to DIR before doing anything else diff --git a/pytype/tools/annotate_ast/annotate_ast.py b/pytype/tools/annotate_ast/annotate_ast.py index e23a5f9c8..2e2a22ce4 100644 --- a/pytype/tools/annotate_ast/annotate_ast.py +++ b/pytype/tools/annotate_ast/annotate_ast.py @@ -47,16 +47,16 @@ class AnnotateAstVisitor(traces.MatchAstVisitor): it is ast-module agnostic so that different AST implementations can be used. """ - def visit_Name(self, node): + def visit_Name(self, node) -> None: self._maybe_annotate(node) - def visit_Attribute(self, node): + def visit_Attribute(self, node) -> None: self._maybe_annotate(node) - def visit_FunctionDef(self, node): + def visit_FunctionDef(self, node) -> None: self._maybe_annotate(node) - def _maybe_annotate(self, node): + def _maybe_annotate(self, node) -> None: """Annotates a node.""" try: ops = self.match(node) @@ -66,7 +66,7 @@ def _maybe_annotate(self, node): unused_loc, entry = next(iter(ops), (None, None)) self._maybe_set_type(node, entry) - def _maybe_set_type(self, node, trace): + def _maybe_set_type(self, node, trace) -> None: """Sets type information on the node, if there is any to set.""" if not trace: return diff --git a/pytype/tools/arg_parser.py b/pytype/tools/arg_parser.py index d066b0dd3..ddcef3ddf 100644 --- a/pytype/tools/arg_parser.py +++ b/pytype/tools/arg_parser.py @@ -9,7 +9,7 @@ # Type alias _ArgDict = dict[str, Any] -Namespace = argparse.Namespace +Namespace: type[argparse.Namespace] = argparse.Namespace @dataclasses.dataclass @@ -32,7 +32,9 @@ def __init__(self, tool_args: Namespace, pytype_opts: pytype_config.Options): class Parser: """Parser that integrates tool and pytype-single args.""" - def __init__(self, parser, *, pytype_single_args=None, 
overrides=None): + def __init__( + self, parser, *, pytype_single_args=None, overrides=None + ) -> None: """Initialize a parser. Args: @@ -88,14 +90,14 @@ def process_parsed_args(self, tool_args: Namespace) -> ParsedArgs: pytype_opts = pytype_config.Options(pytype_args) return ParsedArgs(tool_args, pytype_opts) - def process(self, tool_args, pytype_args): + def process(self, tool_args, pytype_args) -> None: """Process raw pytype args before passing to config.Options.""" # Override in subclasses - def error(self, msg): + def error(self, msg) -> None: self._parser.error(msg) - def _ensure_valid_pytype_args(self, pytype_args: argparse.Namespace): + def _ensure_valid_pytype_args(self, pytype_args: argparse.Namespace) -> None: """Final adjustment of raw pytype args before constructing Options.""" # If we do not have an input file add a dummy one here; tools often need to # construct a config.Options without having an input file. diff --git a/pytype/tools/config.py b/pytype/tools/config.py index 763eb7cbc..ac71f2043 100644 --- a/pytype/tools/config.py +++ b/pytype/tools/config.py @@ -1,18 +1,18 @@ """Utilities for dealing with project configuration.""" import abc -from collections.abc import Iterable +from collections.abc import Generator, Iterable import configparser -from typing import TypeVar +from typing import Any, TypeVar from pytype.platform_utils import path_utils import toml -_CONFIG_FILENAMES = ('pyproject.toml', 'setup.cfg') +_CONFIG_FILENAMES: tuple[str, str] = ('pyproject.toml', 'setup.cfg') _ConfigSectionT = TypeVar('_ConfigSectionT', bound='ConfigSection') -def find_config_file(path): +def find_config_file(path) -> str | None: """Finds the first instance of a config file in a prefix of path.""" # Make sure path is a directory @@ -50,7 +50,7 @@ def items(self) -> Iterable[tuple[str, str]]: class TomlConfigSection(ConfigSection): """A section of a TOML config file.""" - def __init__(self, content): + def __init__(self, content) -> None: self._content = 
content @classmethod @@ -63,7 +63,7 @@ def create_from_file(cls, filepath, section): return cls(content['tool'][section]) return None - def items(self): + def items(self) -> Generator[tuple[Any, str], Any, None]: for k, v in self._content.items(): yield (k, ' '.join(str(e) for e in v) if isinstance(v, list) else str(v)) @@ -71,7 +71,7 @@ def items(self): class IniConfigSection(ConfigSection): """A section of an INI config file.""" - def __init__(self, parser, section): + def __init__(self, parser, section) -> None: self._parser = parser self._section = section diff --git a/pytype/tools/environment.py b/pytype/tools/environment.py index 9ede735fa..a8a898d6d 100644 --- a/pytype/tools/environment.py +++ b/pytype/tools/environment.py @@ -2,13 +2,15 @@ import logging import sys +from typing import Any from pytype.imports import typeshed +from pytype.imports.typeshed import Typeshed from pytype.platform_utils import path_utils from pytype.tools import runner -def check_pytype_or_die(): +def check_pytype_or_die() -> None: if not runner.can_run("pytype", "-h"): logging.critical( "Cannot run pytype. Check that it is installed and in your path" @@ -16,7 +18,7 @@ def check_pytype_or_die(): sys.exit(1) -def check_python_version(exe: list[str], required): +def check_python_version(exe: list[str], required) -> tuple[bool, Any]: """Check if exe is a python executable with the required version.""" try: # python --version outputs to stderr for earlier versions @@ -52,7 +54,7 @@ def check_python_exe_or_die(required) -> list[str]: sys.exit(1) -def initialize_typeshed_or_die(): +def initialize_typeshed_or_die() -> Typeshed: """Initialize a Typeshed object or die. 
Returns: @@ -65,7 +67,7 @@ def initialize_typeshed_or_die(): sys.exit(1) -def compute_pythonpath(filenames): +def compute_pythonpath(filenames) -> list: """Compute a list of dependency paths.""" paths = set() for f in filenames: diff --git a/pytype/tools/runner.py b/pytype/tools/runner.py index 6944e59b9..e47613036 100644 --- a/pytype/tools/runner.py +++ b/pytype/tools/runner.py @@ -10,7 +10,7 @@ class BinaryRun: ret, out, err = BinaryRun([exe, arg, ...]).communicate() """ - def __init__(self, args, dry_run=False): + def __init__(self, args, dry_run=False) -> None: self.args = args self.results = None @@ -31,7 +31,7 @@ def communicate(self): return self.results -def can_run(exe, *args): +def can_run(exe, *args) -> bool: """Check if running exe with args works.""" try: BinaryRun([exe] + list(args)).communicate() diff --git a/pytype/tools/tool_utils.py b/pytype/tools/tool_utils.py index edf0fa94c..de5b74772 100644 --- a/pytype/tools/tool_utils.py +++ b/pytype/tools/tool_utils.py @@ -6,7 +6,7 @@ from pytype import file_utils -def setup_logging_or_die(verbosity): +def setup_logging_or_die(verbosity) -> None: """Set the logging level or die.""" if verbosity == 0: level = logging.ERROR @@ -20,7 +20,7 @@ def setup_logging_or_die(verbosity): logging.basicConfig(level=level, format='%(levelname)s %(message)s') -def makedirs_or_die(path, message): +def makedirs_or_die(path, message) -> None: try: file_utils.makedirs(path) except OSError: diff --git a/pytype/tools/traces/source.py b/pytype/tools/traces/source.py index 1f8d473eb..d7fe213ad 100644 --- a/pytype/tools/traces/source.py +++ b/pytype/tools/traces/source.py @@ -2,10 +2,13 @@ import collections import dataclasses -from typing import Any, NamedTuple +from typing import Any, NamedTuple, TypeVar from pytype.pytd import pytd +_T1 = TypeVar("_T1") +_TAbstractTrace = TypeVar("_TAbstractTrace", bound="AbstractTrace") + class Location(NamedTuple): line: int @@ -19,13 +22,13 @@ class AbstractTrace: symbol: Any types: 
tuple[pytd.Node, ...] - def __new__(cls, op, symbol, types): + def __new__(cls: type[_TAbstractTrace], op, symbol, types) -> _TAbstractTrace: del op, symbol, types # unused if cls is AbstractTrace: raise TypeError("cannot instantiate AbstractTrace") return super().__new__(cls) - def __repr__(self): + def __repr__(self) -> str: return f"{self.op} : {self.symbol} <- {self.types}" @@ -39,7 +42,7 @@ class Code: only if an options object containing the filename was provided. """ - def __init__(self, src, raw_traces, trace_factory, filename): + def __init__(self, src, raw_traces, trace_factory, filename) -> None: """Initializer. Args: @@ -56,7 +59,7 @@ def __init__(self, src, raw_traces, trace_factory, filename): self._offsets = [] self._init_byte_offsets() - def _init_byte_offsets(self): + def _init_byte_offsets(self) -> None: offset = 0 for line in self._lines: self._offsets.append(offset) @@ -72,11 +75,11 @@ def line(self, n): """Gets the text at a line number.""" return self._lines[n - 1] - def get_closest_line_range(self, start, end): + def get_closest_line_range(self, start, end) -> range: """Gets all valid line numbers in the [start, end) line range.""" return range(start, min(end, len(self._lines) + 1)) - def find_first_text(self, start, end, text): + def find_first_text(self, start, end, text) -> Location | None: """Gets first location, if any, the string appears at in the line range.""" for l in self.get_closest_line_range(start, end): @@ -90,7 +93,7 @@ def find_first_text(self, start, end, text): return Location(l, col) return None - def next_non_comment_line(self, line): + def next_non_comment_line(self, line) -> int | None: """Gets the next non-comment line, if any, after the given line.""" for l in range(line + 1, len(self._lines) + 1): if self.line(l).lstrip().startswith("#"): @@ -98,7 +101,7 @@ def next_non_comment_line(self, line): return l return None - def display_traces(self): + def display_traces(self) -> None: """Prints the source file with traces 
for debugging.""" for line in sorted(self.traces): print("%d %s" % (line, self.line(line))) @@ -106,7 +109,9 @@ def display_traces(self): print(f" {trace}") print("-------------------") - def get_attr_location(self, name, location): + def get_attr_location( + self, name, location: _T1 + ) -> tuple[Location | _T1, int]: """Returns the location and span of the attribute in an attribute access. Args: @@ -158,7 +163,7 @@ def _get_multiline_location(self, location, n_lines, text): return None -def _collect_traces(raw_traces, trace_factory): +def _collect_traces(raw_traces, trace_factory) -> collections.defaultdict: """Postprocesses pytype's opcode traces.""" out = collections.defaultdict(list) for op, symbol, data in raw_traces: diff --git a/pytype/tools/traces/traces.py b/pytype/tools/traces/traces.py index a57e01c1e..02c9ccf19 100644 --- a/pytype/tools/traces/traces.py +++ b/pytype/tools/traces/traces.py @@ -1,7 +1,9 @@ """A library for accessing pytype's inferred local types.""" +from collections.abc import Generator import itertools import re +from typing import Any, TypeVar from pytype import analyze from pytype import config @@ -10,22 +12,23 @@ from pytype.pytd import pytd from pytype.pytd import pytd_utils from pytype.pytd import visitors - from pytype.tools.traces import source -_ATTR_OPS = frozenset(( +_T_SymbolMatcher = TypeVar("_T_SymbolMatcher", bound="_SymbolMatcher") + +_ATTR_OPS: frozenset[str] = frozenset(( "LOAD_ATTR", "LOAD_METHOD", "STORE_ATTR", )) -_BINMOD_OPS = frozenset(( +_BINMOD_OPS: frozenset[str] = frozenset(( "BINARY_MODULO", "BINARY_OP", "FORMAT_VALUE", )) -_CALL_OPS = frozenset(( +_CALL_OPS: frozenset[str] = frozenset(( "CALL", "CALL_FUNCTION", "CALL_FUNCTION_EX", @@ -35,15 +38,15 @@ "CALL_METHOD", )) -_LOAD_OPS = frozenset(( +_LOAD_OPS: frozenset[str] = frozenset(( "LOAD_DEREF", "LOAD_FAST", "LOAD_GLOBAL", "LOAD_NAME", )) -_LOAD_SUBSCR_METHODS = ("__getitem__", "__getslice__") -_LOAD_SUBSCR_OPS = frozenset(( +_LOAD_SUBSCR_METHODS: 
tuple[str, str] = ("__getitem__", "__getslice__") +_LOAD_SUBSCR_OPS: frozenset[str] = frozenset(( "BINARY_SLICE", "BINARY_SUBSCR", "SLICE_0", @@ -52,7 +55,7 @@ "SLICE_3", )) -_STORE_OPS = frozenset(( +_STORE_OPS: frozenset[str] = frozenset(( "STORE_DEREF", "STORE_FAST", "STORE_GLOBAL", @@ -108,21 +111,21 @@ class _SymbolMatcher: """ @classmethod - def from_one_match(cls, match): + def from_one_match(cls: type[_T_SymbolMatcher], match) -> _T_SymbolMatcher: return cls((match,)) @classmethod - def from_tuple(cls, matches): + def from_tuple(cls: type[_T_SymbolMatcher], matches) -> _T_SymbolMatcher: return cls(matches) @classmethod - def from_regex(cls, regex): + def from_regex(cls: type[_T_SymbolMatcher], regex) -> _T_SymbolMatcher: return cls((re.compile(regex),)) - def __init__(self, matches): + def __init__(self, matches) -> None: self._matches = matches - def match(self, symbol): + def match(self, symbol) -> bool: for match in self._matches: if isinstance(match, re.Pattern): if match.match(str(symbol)): @@ -139,7 +142,7 @@ class MatchAstVisitor(visitor.BaseVisitor): source: The source and trace information. 
""" - def __init__(self, src_code, *args, **kwargs): + def __init__(self, src_code, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.source = src_code # Needed for x[i] = @@ -147,17 +150,17 @@ def __init__(self, src_code, *args, **kwargs): # For tracking already matched traces self._matched = None - def enter_Assign(self, node): + def enter_Assign(self, node) -> None: if isinstance(node.targets[0], self._ast.Subscript): self._assign_subscr = node.targets[0].value - def leave_Assign(self, _): + def leave_Assign(self, _) -> None: self._assign_subscr = None - def enter_Module(self, _): + def enter_Module(self, _) -> None: self._matched = set() - def leave_Module(self, _): + def leave_Module(self, _) -> None: self._matched = None def match(self, node): @@ -211,10 +214,10 @@ def match_FunctionDef(self, node): for tr in self._get_traces(node.lineno, ["MAKE_FUNCTION"], symbol, 1) ] - def match_Import(self, node): + def match_Import(self, node) -> list: return list(self._match_import(node, is_from=False)) - def match_ImportFrom(self, node): + def match_ImportFrom(self, node) -> list: return list(self._match_import(node, is_from=True)) def match_Lambda(self, node): @@ -231,15 +234,25 @@ def match_Name(self, node): ops = _STORE_OPS else: return [] - return [(self._get_match_location(node), tr) - for tr in self._get_traces(lineno, ops, node.id, 1)] + return [ + (self._get_match_location(node), tr) + for tr in self._get_traces(lineno, ops, node.id, 1) + ] def match_Subscript(self, node): - return [(self._get_match_location(node), tr) for tr in self._get_traces( - node.lineno, _LOAD_SUBSCR_OPS, - _SymbolMatcher.from_tuple(_LOAD_SUBSCR_METHODS), 1)] + return [ + (self._get_match_location(node), tr) + for tr in self._get_traces( + node.lineno, + _LOAD_SUBSCR_OPS, + _SymbolMatcher.from_tuple(_LOAD_SUBSCR_METHODS), + 1, + ) + ] - def _get_traces(self, lineno, ops, symbol, maxmatch=-1, num_lines=1): + def _get_traces( + self, lineno, ops, symbol, maxmatch=-1, 
num_lines=1 + ) -> Generator[Any, Any, None]: """Yields matching traces. Args: @@ -290,10 +303,14 @@ def _get_node_name(self, node): return node.__class__.__name__ def _match_constant(self, node, value): - return [(self._get_match_location(node), tr) - for tr in self._get_traces(node.lineno, ["LOAD_CONST"], value, 1)] + return [ + (self._get_match_location(node), tr) + for tr in self._get_traces(node.lineno, ["LOAD_CONST"], value, 1) + ] - def _match_import(self, node, is_from): + def _match_import( + self, node, is_from + ) -> Generator[tuple[Any, Any], Any, None]: for alias in node.names: name = alias.asname if alias.asname else alias.name op = "STORE_NAME" if alias.asname or is_from else "IMPORT_NAME" @@ -303,11 +320,11 @@ def _match_import(self, node, is_from): class _LineNumberVisitor(visitor.BaseVisitor): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.line = 0 - def generic_visit(self, node): + def generic_visit(self, node) -> None: lineno = getattr(node, "lineno", 0) if lineno > self.line: self.line = lineno diff --git a/pytype/tools/xref/callgraph.py b/pytype/tools/xref/callgraph.py index d521ab9d8..32faeda18 100644 --- a/pytype/tools/xref/callgraph.py +++ b/pytype/tools/xref/callgraph.py @@ -1,12 +1,14 @@ """Trace function arguments, return values and calls to other functions.""" import dataclasses -from typing import Any +from typing import TypeVar, Union, Any from pytype.pytd import escape from pytype.pytd import pytd from pytype.pytd import pytd_utils +_T0 = TypeVar('_T0') + @dataclasses.dataclass class Attr: @@ -48,7 +50,7 @@ class Function: location: Any = dataclasses.field(default=None) -def unknown_to_any(typename): +def unknown_to_any(typename: _T0) -> Union[str, _T0]: if escape.UNKNOWN in typename: return 'typing.Any' return typename @@ -87,7 +89,7 @@ def get_function_params(pytd_fn): class FunctionMap: """Collect a map of function types and outbound callgraph edges.""" 
- def __init__(self, index): + def __init__(self, index) -> None: self.index = index self.fmap = self.init_from_index(index) @@ -106,7 +108,7 @@ def pytd_of_fn(self, f): # TODO(mdemello): log this return None - def init_from_index(self, index): + def init_from_index(self, index) -> dict[Any, Function]: """Initialize the function map.""" out = {} fn_defs = [(k, v) for k, v in index.defs.items() if v.typ == 'FunctionDef'] @@ -130,7 +132,7 @@ def init_from_index(self, index): out['module'] = Function(id='module') return out - def add_attr(self, ref, defn): + def add_attr(self, ref, defn) -> None: """Add an attr access within a function body.""" attrib = ref.name scope = ref.ref_scope @@ -156,7 +158,7 @@ def add_attr(self, ref, defn): else: fn.local_attrs.append(attr_access) - def add_param_def(self, ref, defn): + def add_param_def(self, ref, defn) -> None: """Add a function parameter definition.""" fn = self.fmap[ref.ref_scope] for param in fn.params: @@ -165,13 +167,13 @@ def add_param_def(self, ref, defn): param.type = unwrap_type(self.index.get_pytd(ref.data[0])) break - def add_link(self, ref, defn): + def add_link(self, ref, defn) -> None: if ref.typ == 'Attribute': self.add_attr(ref, defn) if defn.typ == 'Param': self.add_param_def(ref, defn) - def add_call(self, call): + def add_call(self, call) -> None: """Add a function call.""" scope = call.scope if scope not in self.fmap: diff --git a/pytype/tools/xref/debug.py b/pytype/tools/xref/debug.py index 3f186cfdb..2ab31e232 100644 --- a/pytype/tools/xref/debug.py +++ b/pytype/tools/xref/debug.py @@ -9,22 +9,22 @@ # We never care about protected access when writing debug code! 
-def format_loc(location): +def format_loc(location) -> str: # location is (line, column) fmt = "%d:%2d" % location return fmt.rjust(8) -def format_def_with_location(defn, loc): +def format_def_with_location(defn, loc) -> str: return f"{format_loc(loc)} | {defn.typ.ljust(15)} {defn.format()}" -def format_ref(ref): +def format_ref(ref) -> str: return (f"{format_loc(ref.location)} | {ref.typ.ljust(15)} " f"{ref.scope}.{ref.name}") -def format_call(call): +def format_call(call) -> str: return f"{format_loc(call.location)} | {'Call'.ljust(15)} {call.func}" @@ -32,7 +32,7 @@ def typename(node): return node.__class__.__name__ -def show_defs(index): +def show_defs(index) -> None: """Show definitions.""" for def_id in index.locs: defn = index.defs[def_id] @@ -42,7 +42,7 @@ def show_defs(index): print(" "*28 + str(defn.doc)) -def show_refs(index): +def show_refs(index) -> None: """Show references and associated definitions.""" indent = " : " for ref, defn in index.links: @@ -56,7 +56,7 @@ def show_refs(index): continue -def show_calls(index): +def show_calls(index) -> None: for call in index.calls: print(format_call(call)) @@ -75,7 +75,7 @@ def display_type(data): return name -def show_types(index): +def show_types(index) -> None: """Show inferred types.""" out = [] for def_id in index.locs: @@ -105,7 +105,7 @@ def show_types(index): print(f"{format_loc(location)} | {name.ljust(35)} {typ}") -def show_index(index): +def show_index(index) -> None: """Display output in human-readable format.""" def separator(): @@ -119,14 +119,14 @@ def separator(): separator() -def show_map(name, mapping): +def show_map(name, mapping) -> None: print("%s: {" % name) for k, v in mapping.items(): print(" ", k, v) print("}") -def show_kythe_spans(kythe_graph, src): +def show_kythe_spans(kythe_graph, src) -> None: """Show kythe spans.""" for entry in kythe_graph.entries: diff --git a/pytype/tools/xref/indexer.py b/pytype/tools/xref/indexer.py index 4352a668b..03bfa0a1a 100644 --- 
a/pytype/tools/xref/indexer.py +++ b/pytype/tools/xref/indexer.py @@ -24,9 +24,12 @@ from pytype.tools.xref import utils as xref_utils from pytype.tools.xref import node_utils +_T0 = TypeVar("_T0") +_TRemote = TypeVar("_TRemote", bound="Remote") + # A mapping of offsets between a node's start position and the symbol being # defined. e.g. in the declaration "class X" the X is at +6 from the start. -DEF_OFFSETS = { +DEF_OFFSETS: dict[str, int] = { "ClassDef": 6, # class X "FunctionDef": 4, # def f } @@ -50,19 +53,21 @@ def qualified_method(data): return [data.name] -def get_location(node): +def get_location(node) -> source.Location: # TODO(mdemello): The column offset for nodes like "class A" needs to be # adjusted to the start of the symbol. return source.Location(node.lineno, node.col_offset) -def get_end_location(node): +def get_end_location(node) -> source.Location: end_lineno = node.end_lineno end_col_offset = node.end_col_offset return source.Location(end_lineno, end_col_offset) -def match_opcodes(opcode_traces, lineno, op_match_list): +def match_opcodes( + opcode_traces, lineno, op_match_list +) -> list[tuple[Any, Any, Any]]: """Get all opcodes matching op_match_list on a given line. Args: @@ -81,7 +86,7 @@ def match_opcodes(opcode_traces, lineno, op_match_list): return out -def match_opcodes_multiline(opcode_traces, start, end, op_match_list): +def match_opcodes_multiline(opcode_traces, start, end, op_match_list) -> list: """Get all opcodes matching op_match_list in a range of lines.""" out = [] for line in range(start, end + 1): @@ -110,10 +115,10 @@ class PytypeValue: typ: Any id: str | None = dataclasses.field(default=None, init=False) - def __post_init__(self): + def __post_init__(self) -> None: self.id = self.module + "." 
+ self.name - def format(self): + def format(self) -> str: return f"{self.id} {{ {self.module}.{self.typ} : {self.name} }}" @classmethod @@ -149,7 +154,7 @@ def from_data(cls, data): else: return [cls._from_data(x) for x in data] - def to_signature(self): + def to_signature(self) -> str: return self.module + "." + self.name @property @@ -162,10 +167,10 @@ class Module: """Module representation.""" name: str - def attr(self, attr_name): + def attr(self, attr_name) -> "Remote": return Remote(self.name, attr_name, resolved=True) - def submodule(self, attr_name): + def submodule(self, attr_name) -> "Remote": name = self.name + "." + attr_name return Remote(name, IMPORT_FILE_MARKER, resolved=True) @@ -221,20 +226,20 @@ class Definition: doc: str | None id: str | None = dataclasses.field(default=None, init=False) - def __post_init__(self): + def __post_init__(self) -> None: self.id = self.scope + "." + self.name - def format(self): + def format(self) -> str: return self.id - def to_signature(self): + def to_signature(self) -> str: return self.id def doc_signature(self): """Signature for the definition's docstring.""" return self.to_signature() + ".__doc__" - def node_kind(self): + def node_kind(self) -> str: # TODO(mdemello): Add more node types. if self.typ == "ClassDef": return "class" @@ -271,13 +276,13 @@ class Remote: id: str | None = dataclasses.field(default=None, init=False) typ: Any = dataclasses.field(default=None, init=False) - def __post_init__(self): + def __post_init__(self) -> None: self.id = self.module + "/module." + self.name - def attr(self, attr_name): + def attr(self: _TRemote, attr_name) -> _TRemote: return Remote(self.module, self.name + "." 
+ attr_name, self.resolved) - def format(self): + def format(self) -> str: return self.id @property @@ -335,10 +340,10 @@ class Reference: location: source.Location id: str | None = dataclasses.field(default=None, init=False) - def __post_init__(self): + def __post_init__(self) -> None: self.id = self.scope + "." + self.name - def format(self): + def format(self) -> str: return self.id @@ -374,7 +379,7 @@ class Funcall: class Env: """A collection of namespaced symbols.""" - def __init__(self, ast, scope, parent, cls): + def __init__(self, ast, scope, parent, cls) -> None: """Initialize an environment. Arguments: @@ -411,7 +416,7 @@ def lookup(self, symbol): def __getitem__(self, symbol): return self.lookup(symbol)[1] - def __setitem__(self, symbol, value): + def __setitem__(self, symbol, value) -> None: self.env[symbol] = value def is_self_attr(self, node): @@ -459,7 +464,7 @@ class ScopedVisitor(ast_visitor.BaseVisitor): # TODO(mdemello): Is the two-level visitor hierarchy really buying us # anything by way of maintainability or readability? 
- def __init__(self, ast, module_name, **kwargs): + def __init__(self, ast, module_name, **kwargs) -> None: super().__init__(ast=ast, **kwargs) self.stack = [] self.class_ids = [] @@ -479,11 +484,11 @@ def get_id(self, node): else: raise Exception(f"Unexpected scope: {node!r}") # pylint: disable=broad-exception-raised - def iprint(self, x): + def iprint(self, x) -> None: """Print messages indented by scope level, for debugging.""" print(" " * len(self.stack), x) - def scope_id(self): + def scope_id(self) -> str: return ".".join(self.get_id(x) for x in self.stack) @property @@ -497,7 +502,7 @@ def current_env(self): current_scope = self.scope_id() return self.envs[current_scope] - def add_scope(self, node, is_class=False): + def add_scope(self, node, is_class=False) -> Env: if self.stack: parent = self.current_env else: @@ -513,23 +518,23 @@ def add_scope(self, node, is_class=False): self.envs[new_scope] = new_env return new_env - def enter_ClassDef(self, node): + def enter_ClassDef(self, node) -> None: new_env = self.add_scope(node, is_class=True) self.class_ids.append(self.scope_id()) # We need to set the env's cls to the new class, not the enclosing one. 
new_env.cls = self.current_class - def leave_ClassDef(self, _): + def leave_ClassDef(self, _) -> None: self.class_ids.pop() - def enter_FunctionDef(self, node): + def enter_FunctionDef(self, node) -> None: self.add_scope(node) - def enter_Module(self, node): + def enter_Module(self, node) -> None: super().enter_Module(node) # pytype: disable=attribute-error self.add_scope(node) - def leave(self, node): + def leave(self, node) -> None: """If the node has introduced a new scope, we need to pop it off.""" super().leave(node) if node == self.stack[-1]: @@ -539,7 +544,7 @@ def leave(self, node): class IndexVisitor(ScopedVisitor, traces.MatchAstVisitor): """Visitor that generates indexes.""" - def __init__(self, ast, src, module_name): + def __init__(self, ast, src, module_name) -> None: super().__init__(ast=ast, src_code=src, module_name=module_name) self.defs = {} self.locs = collections.defaultdict(list) @@ -585,7 +590,7 @@ def _get_node_name(self, node): return node return super()._get_node_name(node) - def make_def(self, node, **kwargs): + def make_def(self, node, **kwargs) -> tuple[Definition, DefLocation]: """Make a definition from a node.""" if isinstance(node, self._ast.Name): @@ -609,7 +614,7 @@ def make_def(self, node, **kwargs): defloc = DefLocation(defn.id, source.Location(line, col)) return (defn, defloc) - def make_ref(self, node, **kwargs): + def make_ref(self, node, **kwargs) -> Reference: """Make a reference from a node.""" assert "data" in kwargs # required kwarg @@ -662,33 +667,33 @@ def add_global_ref(self, node, **kwargs): kwargs.update({"ref_scope": "module"}) return self.add_local_ref(node, **kwargs) - def add_call(self, node, name, func, arg_varnames, return_type): + def add_call(self, node, name, func, arg_varnames, return_type) -> None: start = get_location(node) end = get_end_location(node) self.calls.append( Funcall(name, self.scope_id(), func, start, end, arg_varnames, return_type)) - def add_attr(self, node): + def add_attr(self, node) -> 
None: defn, _ = self.make_def(node) self.defs[defn.id] = defn env = self.envs[self.scope_id()] if env.is_self_attr(node): self.envs[self.scope_id()].setattr(node.attr, defn) - def _has_decorator(self, f, decorator): + def _has_decorator(self, f, decorator) -> bool: for d in f.decorator_list: if isinstance(d, self._ast.Name) and d.id == decorator: return True return False - def _record_childof(self, node, defn): + def _record_childof(self, node, defn) -> None: """Record a childof relationship for nested definitions.""" parent = self.scope_defn.get(self.scope_id()) if parent: self.childof.append((defn, parent)) - def enter_ClassDef(self, node): + def enter_ClassDef(self, node) -> None: class_name = node_utils.get_name(node, self._ast) last_line = max(node.lineno, node.body[0].lineno - 1) @@ -721,7 +726,7 @@ def enter_ClassDef(self, node): super().enter_ClassDef(node) self.scope_defn[self.scope_id()] = defn - def enter_FunctionDef(self, node): + def enter_FunctionDef(self, node) -> None: last_line = max(node.lineno, node.body[0].lineno - 1) ops = match_opcodes_multiline(self.traces, node.lineno, last_line, [ ("MAKE_FUNCTION", None), # py2 has no symbol, py3 has node.name @@ -803,17 +808,17 @@ def visit_Call(self, node): seen.add(f) return name - def visit_Assign(self, node): + def visit_Assign(self, node) -> None: for v in node.targets: if isinstance(v, self._ast.Attribute): self.add_attr(v) - def visit_AnnAssign(self, node): + def visit_AnnAssign(self, node) -> None: parent = self.scope_defn.get(self.scope_id()) if parent and parent.typ == "ClassDef": self.add_local_def(node, name=node.target) - def _add_attr_ref(self, node, node_str, trace): + def _add_attr_ref(self, node, node_str, trace) -> None: ref = self.add_local_ref( node, target=node.value, @@ -847,13 +852,13 @@ def visit_Attribute(self, node): def visit_Subscript(self, node): return node.value - def visit_DictComp(self, _node): + def visit_DictComp(self, _node) -> str: return "" - def visit_ListComp(self, 
_node): + def visit_ListComp(self, _node) -> str: return "" - def process_import(self, node): + def process_import(self, node) -> None: """Common code for Import and ImportFrom.""" for alias, (loc, trace) in zip(node.names, self.match(node)): @@ -921,17 +926,17 @@ def process_import(self, node): for mod in module_utils.get_all_prefixes(symbol): self.modules[self.scope_id() + "." + mod] = mod - def visit_Import(self, node): + def visit_Import(self, node) -> None: self.process_import(node) - def visit_ImportFrom(self, node): + def visit_ImportFrom(self, node) -> None: self.process_import(node) - def enter_Return(self, node): + def enter_Return(self, node) -> None: if isinstance(node.value, self._ast.Name): self.current_env.ret = _RETURNING_NAME - def leave_Return(self, node): + def leave_Return(self, node) -> None: if self.current_env.ret == _RETURNING_NAME: self.current_env.ret = None @@ -949,7 +954,7 @@ def __init__(self, src, loader, pytd_module, - module_name): + module_name) -> None: self.ast = ast self.source = src self.loader = loader @@ -975,7 +980,7 @@ def __init__(self, # Optionally preserve the pytype vm so we can access the types later self.vm = None - def index(self, code_ast): + def index(self, code_ast) -> None: """Index an AST corresponding to self.source.""" v = IndexVisitor(self.ast, self.source, self.module_name) @@ -992,7 +997,7 @@ def index(self, code_ast): self.function_params = v.function_params self.childof = v.childof - def get_def_offsets(self, defloc): + def get_def_offsets(self, defloc) -> tuple[Any, Any]: """Get the byte offsets for a definition.""" assert self.defs is not None @@ -1007,14 +1012,14 @@ def get_def_offsets(self, defloc): end = start + len(defn.name) return (start, end) - def get_doc_offsets(self, doc): + def get_doc_offsets(self, doc) -> tuple[Any, Any]: """Get the byte offsets for a docstring.""" start = self.source.get_offset(doc.location) end = start + doc.length return (start, end) - def finalize(self): + def 
finalize(self) -> None: """Postprocess the information gathered by the tree visitor.""" self.links = self._lookup_refs() @@ -1023,7 +1028,7 @@ def _get_attr_bounds(self, name, location): return self.get_anchor_bounds( *self.source.get_attr_location(name, location)) - def get_anchor_bounds(self, location, length): + def get_anchor_bounds(self, location, length) -> tuple[Any, Any]: """Generate byte offsets from a location and length.""" start = self.source.get_offset(location) @@ -1100,10 +1105,12 @@ def _get_mro(self, obj): else: return [] - def _is_pytype_module(self, obj): + def _is_pytype_module(self, obj) -> bool: return isinstance(obj, abstract.Module) - def _lookup_attribute_by_type(self, r, attr_name): + def _lookup_attribute_by_type( + self, r: _T0, attr_name + ) -> list[tuple[_T0, Any]]: """Look up an attribute using pytype annotations.""" links = [] @@ -1152,7 +1159,7 @@ def _lookup_attribute_by_type(self, r, attr_name): break return links - def _lookup_refs(self): + def _lookup_refs(self) -> list: """Look up references to generate links.""" links = [] @@ -1238,7 +1245,7 @@ def get_pytd(self, datum): visitors.RemoveUnknownClasses()) return self.loader.resolve_pytd(t, self.pytd_module) - def make_serializable(self): + def make_serializable(self) -> None: """Delete all data that cannot be pickled.""" for r in self.refs: r.target = None @@ -1283,7 +1290,7 @@ class PytypeError(Exception): class VmTrace(source.AbstractTrace): - def __repr__(self): + def __repr__(self) -> str: types_repr = tuple( t and [node_utils.typename(x) for x in t] for t in self.types) diff --git a/pytype/tools/xref/kythe.py b/pytype/tools/xref/kythe.py index 1a190944c..2a25af8f5 100644 --- a/pytype/tools/xref/kythe.py +++ b/pytype/tools/xref/kythe.py @@ -48,7 +48,7 @@ class Edge: class Kythe: """Store a list of kythe graph entries.""" - def __init__(self, source, args=None): + def __init__(self, source, args=None) -> None: if args: self.root = args.root self.corpus = args.corpus @@ -64,13 
+64,13 @@ def __init__(self, source, args=None): self.file_vname = self._add_file(source.text) self._add_file_anchor() - def _encode(self, value): + def _encode(self, value) -> str: """Encode fact values as base64.""" value_bytes = bytes(value, "utf-8") encoded_bytes = base64.b64encode(value_bytes) return encoded_bytes.decode("utf-8") - def _add_file(self, file_contents): + def _add_file(self, file_contents) -> VName: # File vnames are special-cased to have an empty signature and lang. vname = VName( signature="", language="", path=self.path, root=self.root, @@ -79,7 +79,7 @@ def _add_file(self, file_contents): self.add_fact(vname, "text", file_contents) return vname - def _add_file_anchor(self): + def _add_file_anchor(self) -> None: # Add a special anchor for the first byte of a file, so we can link to it. anchor_vname = self.add_anchor(0, 0) mod_vname = self.vname(FILE_ANCHOR_SIGNATURE) @@ -87,14 +87,14 @@ def _add_file_anchor(self): self.add_edge(anchor_vname, "defines/implicit", mod_vname) self.add_edge(self.file_vname, "childof", mod_vname) - def _add_entry(self, entry): + def _add_entry(self, entry) -> None: """Make sure we don't have duplicate entries.""" if entry in self._seen_entries: return self._seen_entries.add(entry) self.entries.append(entry) - def vname(self, signature, filepath=None, root=None): + def vname(self, signature, filepath=None, root=None) -> VName: return VName( signature=signature, path=filepath or self.path, @@ -102,7 +102,7 @@ def vname(self, signature, filepath=None, root=None): root=root or self.root, corpus=self.corpus) - def stdlib_vname(self, signature, filepath=None): + def stdlib_vname(self, signature, filepath=None) -> VName: return VName( signature=signature, path=filepath or self.path, @@ -114,12 +114,12 @@ def anchor_vname(self, start, end): signature = "@%d:%d" % (start, end) return self.vname(signature) - def fact(self, source, fact_name, fact_value): + def fact(self, source, fact_name, fact_value) -> Fact: fact_name = 
"/kythe/" + fact_name fact_value = self._encode(fact_value) return Fact(source, fact_name, fact_value) - def edge(self, source, edge_name, target): + def edge(self, source, edge_name, target) -> Edge: edge_kind = "/kythe/edge/" + edge_name return Edge(source, edge_kind, target, "/") @@ -148,10 +148,10 @@ def add_anchor(self, start, end): def _process_deflocs(kythe: Kythe, index: indexer.Indexer): """Generate kythe edges for definitions.""" - for def_id in index.locs: - defn = index.defs[def_id] - for defloc in index.locs[def_id]: - defn = index.defs[defloc.def_id] + for def_id in index.locs: # pytype: disable=attribute-error + defn = index.defs[def_id] # pytype: disable=unsupported-operands + for defloc in index.locs[def_id]: # pytype: disable=unsupported-operands + defn = index.defs[defloc.def_id] # pytype: disable=unsupported-operands defn_vname = kythe.vname(defn.to_signature()) start, end = index.get_def_offsets(defloc) anchor_vname = kythe.add_anchor(start, end) @@ -171,7 +171,7 @@ def _process_deflocs(kythe: Kythe, index: indexer.Indexer): source=anchor_vname, target=defn_vname, edge_name="defines/binding") try: - alias = index.aliases[defn.id] + alias = index.aliases[defn.id] # pytype: disable=unsupported-operands except KeyError: pass else: @@ -195,7 +195,7 @@ def _process_deflocs(kythe: Kythe, index: indexer.Indexer): source=doc_vname, target=defn_vname, edge_name="documents") -def _process_params(kythe, index): +def _process_params(kythe, index) -> None: """Generate kythe edges for function parameters.""" for fp in index.function_params: @@ -273,7 +273,7 @@ def _process_childof(kythe: Kythe, index: indexer.Indexer): kythe.add_edge(source=source, target=target, edge_name="childof") -def _process_calls(kythe, index): +def _process_calls(kythe, index) -> None: """Generate kythe edges for function calls.""" # Checks if a function call corresponds to a resolved reference, and generates @@ -314,7 +314,7 @@ def _process_calls(kythe, index): assert False, ref # 
pytype: disable=name-error -def generate_graph(index, kythe_args): +def generate_graph(index, kythe_args) -> Kythe: kythe = Kythe(index.source, kythe_args) _process_deflocs(kythe, index) _process_params(kythe, index) diff --git a/pytype/tools/xref/output.py b/pytype/tools/xref/output.py index a11ae3fb9..a3223be12 100644 --- a/pytype/tools/xref/output.py +++ b/pytype/tools/xref/output.py @@ -2,6 +2,7 @@ import dataclasses import json +from typing import Any, Generator def unpack(obj): @@ -19,13 +20,13 @@ def unpack(obj): return obj -def json_kythe_graph(kythe_graph): +def json_kythe_graph(kythe_graph) -> Generator[str, Any, None]: """Generate kythe entries.""" for x in kythe_graph.entries: yield json.dumps(dataclasses.asdict(x)) -def output_kythe_graph(kythe_graph): +def output_kythe_graph(kythe_graph) -> None: for x in json_kythe_graph(kythe_graph): print(x) diff --git a/pytype/tools/xref/testdata/builtins.py b/pytype/tools/xref/testdata/builtins.py index f80e3f2b8..7ed4215ba 100644 --- a/pytype/tools/xref/testdata/builtins.py +++ b/pytype/tools/xref/testdata/builtins.py @@ -2,6 +2,6 @@ a = "hello" #- @split ref vname("module.str.split", _, _, "pytd:builtins", _) -b = a.split('.') +b: list[str] = a.split('.') #- @reverse ref vname("module.list.reverse", _, _, "pytd:builtins", _) -c = b.reverse() +c: None = b.reverse() diff --git a/pytype/tracer_vm.py b/pytype/tracer_vm.py index 74a0fd660..538d4e945 100644 --- a/pytype/tracer_vm.py +++ b/pytype/tracer_vm.py @@ -6,7 +6,7 @@ import enum import logging import re -from typing import Any, Union +from typing import TypeVar, Any, Union import attrs from pytype import state as frame_state @@ -23,18 +23,25 @@ from pytype.pytd import visitors from pytype.typegraph import cfg -log = logging.getLogger(__name__) +_T0 = TypeVar("_T0") + +log: logging.Logger = logging.getLogger(__name__) # Most interpreter functions (including lambdas) need to be analyzed as # stand-alone functions.
The exceptions are comprehensions and generators, which # have names like "<listcomp>" and "<genexpr>". -_SKIP_FUNCTION_RE = re.compile(r"<(?!lambda)\w+>$") +_SKIP_FUNCTION_RE: re.Pattern = re.compile(r"<(?!lambda)\w+>$") _InstanceCacheType = dict[ abstract.InterpreterClass, dict[Any, Union["_InitClassState", cfg.Variable]] ] -_METHOD_TYPES = abstract.INTERPRETER_FUNCTION_TYPES + ( +_METHOD_TYPES: tuple[ + type[abstract.BoundInterpreterFunction], + type[abstract.InterpreterFunction], + type[special_builtins.StaticMethodInstance], + type[special_builtins.ClassMethodInstance], +] = abstract.INTERPRETER_FUNCTION_TYPES + ( special_builtins.StaticMethodInstance, special_builtins.ClassMethodInstance, ) @@ -58,7 +65,7 @@ class _InitClassState(enum.Enum): class CallTracer(vm.VirtualMachine): """Virtual machine that records all function calls.""" - _CONSTRUCTORS = ("__new__", "__init__") + _CONSTRUCTORS: tuple[str, str] = ("__new__", "__init__") def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -90,7 +97,9 @@ def create_kwargs(self, node): kwargs.merge_instance_type_parameter(node, abstract_utils.V, value_type) return kwargs.to_variable(node) - def create_method_arguments(self, node, method, use_defaults=False): + def create_method_arguments( + self, node: _T0, method, use_defaults=False + ) -> tuple[_T0, function.Args]: """Create arguments for the given method. Creates Unknown objects as arguments for the given method. Note that we @@ -134,7 +143,7 @@ def create_method_arguments(self, node, method, use_defaults=False): starstarargs=starstarargs, ) - def call_function_with_args(self, node, val, args): + def call_function_with_args(self, node, val, args) -> tuple[Any, Any]: """Call a function.
Args: @@ -154,7 +163,7 @@ def call_function_with_args(self, node, val, args): def _call_function_in_frame( self, node, val, args, kwargs, starargs, starstarargs - ): + ) -> tuple[Any, Any]: # Try to get the function opcode with position information to construct the # frame, so that we have the data for any error messages raised before we # get to func.call(). @@ -223,7 +232,7 @@ def maybe_analyze_method(self, node, val, cls=None): node, _ = self.call_function_with_args(node, val, args) return node - def call_with_fake_args(self, node0, funcv): + def call_with_fake_args(self, node0, funcv) -> tuple[Any, Any]: """Attempt to call the given function with made-up arguments.""" # Note that this should only be used for functions that raised a # FailedFunctionCall error. This is not guaranteed to successfully call a @@ -265,7 +274,7 @@ def call_with_fake_args(self, node0, funcv): log.info("Unable to generate fake arguments for %s", funcv) return node, self.ctx.new_unsolvable(node) - def analyze_method_var(self, node0, name, var, cls): + def analyze_method_var(self, node0: _T0, name, var, cls) -> _T0: full_name = f"{cls.data.full_name}.{name}" if any( isinstance(v, abstract.INTERPRETER_FUNCTION_TYPES) for v in var.data @@ -294,7 +303,7 @@ def _bind_method(self, node, methodvar, instance_var): def _maybe_instantiate_binding_directly( self, node0, cls, container, instantiate_directly - ): + ) -> tuple[Any, Any, Any]: node1, new = cls.data.get_own_new(node0, cls) if not new: instantiate_directly = True @@ -310,7 +319,9 @@ def _maybe_instantiate_binding_directly( instance = None return node1, new, instance - def _instantiate_binding(self, node0, cls, container, instantiate_directly): + def _instantiate_binding( + self, node0, cls, container, instantiate_directly + ): """Instantiate a class binding.""" node1, new, maybe_instance = self._maybe_instantiate_binding_directly( node0, cls, container, instantiate_directly @@ -331,7 +342,9 @@ def _instantiate_binding(self, node0, cls, 
container, instantiate_directly): nodes.append(node4) return self.ctx.join_cfg_nodes(nodes), instance - def _instantiate_var(self, node, clsv, container, instantiate_directly): + def _instantiate_var( + self, node, clsv, container, instantiate_directly + ) -> tuple[Any, Any]: """Build an (dummy) instance from a class, for analyzing it.""" n = self.ctx.program.NewVariable() for cls in clsv.Bindings(node): @@ -341,7 +354,7 @@ def _instantiate_var(self, node, clsv, container, instantiate_directly): n.PasteVariable(var) return node, n - def _mark_maybe_missing_members(self, values): + def _mark_maybe_missing_members(self, values) -> None: """Set maybe_missing_members to True on these values and their type params. Args: @@ -362,7 +375,7 @@ def _mark_maybe_missing_members(self, values): def init_class_and_forward_node( self, node, cls, container=None, extra_key=None - ): + ) -> tuple[Any, Any]: """Instantiate a class, and also call __init__. Calling __init__ can be expensive, so this method caches its created @@ -423,7 +436,9 @@ def init_class_and_forward_node( def init_class(self, node, cls, container=None, extra_key=None): return self.init_class_and_forward_node(node, cls, container, extra_key)[-1] - def get_bound_method(self, node, obj, method_name, valself): + def get_bound_method( + self, node, obj, method_name, valself + ) -> tuple[Any, Any]: def bind(cur_node, m): return self._bind_method(cur_node, m, valself.AssignToNewVariable()) @@ -480,7 +495,7 @@ def call_init(self, node, instance): node = self._call_init_on_binding(node, b) return node - def reinitialize_if_initialized(self, node, instance): + def reinitialize_if_initialized(self, node, instance) -> None: if instance in self._initialized_instances: self._call_init_on_binding(node, instance.to_binding(node)) @@ -533,7 +548,7 @@ def analyze_class(self, node, val): node = self.analyze_method_var(node, name, b, val) return node - def analyze_function(self, node0, val): + def analyze_function(self, node0: _T0, 
val) -> _T0: if val.data.is_attribute_of_class: # We'll analyze this function as part of a class. log.info("Analyze functions: Skipping class method %s", val.data.name) @@ -543,7 +558,7 @@ def analyze_function(self, node0, val): node2.ConnectTo(node0) return node0 - def _should_analyze_as_interpreter_function(self, data): + def _should_analyze_as_interpreter_function(self, data) -> bool: # We record analyzed functions by opcode rather than function object. The # two ways of recording are equivalent except for closures, which are # re-generated when the variables they close over change, but we don't want @@ -629,7 +644,7 @@ def trace_functiondef(self, f): def trace_classdef(self, c): self._interpreter_classes.append(c) - def pytd_classes_for_unknowns(self): + def pytd_classes_for_unknowns(self) -> list[None]: classes = [] for name, val in self._unknowns.items(): log.info("Generating structural definition for unknown: %r", name) @@ -644,7 +659,7 @@ def _skip_definition_export(self, name, var): or self._is_future_feature(var) ) - def pytd_for_types(self, defs): + def pytd_for_types(self, defs) -> pytd.TypeDeclUnit: # If a variable is annotated, we'll always output that type. 
annotated_names = set() data = [] @@ -725,7 +740,9 @@ def pytd_for_types(self, defs): return pytd_utils.WrapTypeDeclUnit("inferred", data) @staticmethod - def _call_traces_to_function(call_traces, name_transform=lambda x: x): + def _call_traces_to_function( + call_traces, name_transform=lambda x: x + ) -> list[pytd.Function]: funcs = collections.defaultdict(pytd_utils.OrderedSet) def to_pytd_type(node, arg): @@ -779,7 +796,7 @@ def to_pytd_type(node, arg): ) return functions - def _is_typing_member(self, name, var): + def _is_typing_member(self, name, var) -> bool: for module_name in ("typing", "typing_extensions"): if module_name not in self.loaded_overlays: continue @@ -790,7 +807,7 @@ def _is_typing_member(self, name, var): return True return False - def _is_future_feature(self, var): + def _is_future_feature(self, var) -> bool: for v in var.data: if isinstance(v, abstract.Instance) and v.cls.module == "__future__": return True @@ -799,7 +816,7 @@ def _is_future_feature(self, var): def pytd_functions_for_call_traces(self): return self._call_traces_to_function(self._calls, escape.pack_partial) - def pytd_classes_for_call_traces(self): + def pytd_classes_for_call_traces(self) -> list[pytd.Class]: class_to_records = collections.defaultdict(list) for call_record in self._method_calls: args = call_record.positional_arguments diff --git a/pytype/typegraph/cfg_utils.py b/pytype/typegraph/cfg_utils.py index 5f3d56ead..9b6926bb6 100644 --- a/pytype/typegraph/cfg_utils.py +++ b/pytype/typegraph/cfg_utils.py @@ -1,9 +1,9 @@ """Utilities for working with the CFG.""" import collections -from collections.abc import Iterable, Sequence +from collections.abc import Generator, Iterable, Sequence import itertools -from typing import Protocol, TypeVar +from typing import Any, Protocol, TypeVar # Limit on how many argument combinations we allow before aborting. 
@@ -18,7 +18,7 @@ DEEP_VARIABLE_LIMIT = 1024 -def variable_product(variables): +def variable_product(variables) -> itertools.product: """Take the Cartesian product of a number of Variables. Args: @@ -31,7 +31,9 @@ def variable_product(variables): return itertools.product(*(v.bindings for v in variables)) -def _variable_product_items(variableitems, complexity_limit): +def _variable_product_items( + variableitems, complexity_limit +) -> Generator[list, Any, None]: """Take the Cartesian product of a list of (key, value) tuples. See variable_product_dict below. @@ -63,11 +65,11 @@ class TooComplexError(Exception): class ComplexityLimit: """A class that raises TooComplexError if we hit a limit.""" - def __init__(self, limit): + def __init__(self, limit) -> None: self.limit = limit self.count = 0 - def inc(self, add=1): + def inc(self, add=1) -> None: self.count += add if self.count >= self.limit: raise TooComplexError() @@ -104,7 +106,9 @@ def deep_variable_product(variables, limit=DEEP_VARIABLE_LIMIT): ) -def _deep_values_list_product(values_list, seen, complexity_limit): +def _deep_values_list_product( + values_list, seen, complexity_limit +) -> list[tuple]: """Take the deep Cartesian product of a list of list of Values.""" result = [] for row in itertools.product(*(values for values in values_list if values)): @@ -196,7 +200,9 @@ def merge_bindings(program, node, bindings): return v -def walk_binding(binding, keep_binding=lambda _: True): +def walk_binding( + binding, keep_binding=lambda _: True +) -> Generator[Any, Any, None]: """Helper function to walk a binding's origins. Args: @@ -332,7 +338,7 @@ def order_nodes(nodes: Sequence[_OrderableNode]) -> list[_OrderableNode]: return order -def topological_sort(nodes): +def topological_sort(nodes) -> Generator[Any, Any, None]: """Sort a list of nodes topologically. 
This will order the nodes so that any node that appears in the "incoming" diff --git a/pytype/typegraph/typegraph_serializer.py b/pytype/typegraph/typegraph_serializer.py index 2ac47b070..54385817b 100644 --- a/pytype/typegraph/typegraph_serializer.py +++ b/pytype/typegraph/typegraph_serializer.py @@ -96,7 +96,7 @@ class SerializedProgram: class TypegraphEncoder(json.JSONEncoder): """Implements the JSONEncoder behavior for typegraph objects.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._bindings: dict[int, cfg.Binding] = {} @@ -214,7 +214,18 @@ def default(self, o): return super().default(o) -_TYP_MAP = { +_TYP_MAP: dict[ + str, + type[ + SerializedBinding + | SerializedCFGNode + | SerializedOrigin + | SerializedProgram + | SerializedQuery + | SerializedQueryStep + | SerializedVariable + ], +] = { "Program": SerializedProgram, "CFGNode": SerializedCFGNode, "Variable": SerializedVariable, @@ -225,7 +236,17 @@ def default(self, o): } -def _decode(obj): +def _decode( + obj, +) -> ( + SerializedBinding + | SerializedCFGNode + | SerializedOrigin + | SerializedProgram + | SerializedQuery + | SerializedQueryStep + | SerializedVariable +): typ = obj.pop("_type") return _TYP_MAP[typ](**obj) diff --git a/pytype/types/base.py b/pytype/types/base.py index be66511f8..3bdb84ca3 100644 --- a/pytype/types/base.py +++ b/pytype/types/base.py @@ -25,4 +25,4 @@ def to_pytd_type_of_instance(self, *args, **kwargs) -> pytd.Type: # variables or expressions to abstract values. Variables are an internal # implementation detail that no external code should depend on; we define a # Variable type alias here simply to use in type signatures. -Variable = Any +Variable: Any = Any diff --git a/pytype/types/types.py b/pytype/types/types.py index 1e7f146c2..174d951ff 100644 --- a/pytype/types/types.py +++ b/pytype/types/types.py @@ -15,22 +15,23 @@ library. 
""" +from typing import Any from pytype.types import base from pytype.types import classes from pytype.types import functions from pytype.types import instances -BaseValue = base.BaseValue -Variable = base.Variable +BaseValue: type[base.BaseValue] = base.BaseValue +Variable: Any = base.Variable -Attribute = classes.Attribute -Class = classes.Class +Attribute: type[classes.Attribute] = classes.Attribute +Class: type[classes.Class] = classes.Class -Arg = functions.Arg -Args = functions.Args -Function = functions.Function -Signature = functions.Signature +Arg: type[functions.Arg] = functions.Arg +Args: type[functions.Args] = functions.Args +Function: type[functions.Function] = functions.Function +Signature: type[functions.Signature] = functions.Signature -Module = instances.Module -PythonConstant = instances.PythonConstant +Module: type[instances.Module] = instances.Module +PythonConstant: type[instances.PythonConstant] = instances.PythonConstant diff --git a/pytype/utils.py b/pytype/utils.py index 88d42f974..cad346247 100644 --- a/pytype/utils.py +++ b/pytype/utils.py @@ -7,6 +7,7 @@ import re import threading import traceback +from typing import Any, Callable import weakref from pytype.platform_utils import path_utils @@ -23,7 +24,7 @@ # We disable the check that keeps pytype from running on not-yet-supported # versions when we detect that a pytype test is executing, in order to be able # to test upcoming versions. 
-def _validate_python_version_upper_bound(): +def _validate_python_version_upper_bound() -> bool: for frame_summary in traceback.extract_stack(): head, tail = path_utils.split(frame_summary.filename) if "/pytype/" in head + "/" and ( @@ -40,17 +41,17 @@ class UsageError(Exception): """Raise this for top-level usage errors.""" -def format_version(python_version): +def format_version(python_version) -> str: """Format a version tuple into a dotted version string.""" return ".".join(str(x) for x in python_version) -def version_from_string(version_string): +def version_from_string(version_string) -> tuple: """Parse a version string like "3.7" into a tuple.""" return tuple(map(int, version_string.split("."))) -def validate_version(python_version): +def validate_version(python_version) -> None: """Raise an exception if the python version is unsupported.""" if len(python_version) != 2: # This is typically validated in the option parser, but check here too in @@ -86,7 +87,7 @@ def strip_prefix(string, prefix): return string -def maybe_truncate(s, length=30): +def maybe_truncate(s, length=30) -> str: """Truncate long strings (and append '...'), but leave short strings alone.""" s = str(s) if len(s) > length - 3: @@ -114,7 +115,7 @@ def pretty_conjunction(conjunction): return "(" + " & ".join(conjunction) + ")" -def pretty_dnf(dnf): +def pretty_dnf(dnf) -> str: """Pretty-print a disjunctive normal form (disjunction of conjunctions). E.g. [["a", "b"], ["c"]] -> "(a & b) | c". 
@@ -131,11 +132,11 @@ def pretty_dnf(dnf): return " | ".join(pretty_conjunction(c) for c in dnf) -def numeric_sort_key(s): +def numeric_sort_key(s) -> tuple: return tuple((int(e) if e.isdigit() else e) for e in re.split(r"(\d+)", s)) -def concat_tuples(tuples): +def concat_tuples(tuples) -> tuple: return tuple(itertools.chain.from_iterable(tuples)) @@ -157,7 +158,7 @@ def list_strip_prefix(l, prefix): return l[len(prefix) :] if list_startswith(l, prefix) else l -def invert_dict(d): +def invert_dict(d) -> collections.defaultdict: """Invert a dictionary. Converts a dictionary (mapping strings to lists of strings) to a dictionary @@ -177,7 +178,7 @@ def invert_dict(d): return inverted -def unique_list(xs): +def unique_list(xs) -> list: """Return a unique list from an iterable, preserving order.""" seen = set() out = [] @@ -206,7 +207,7 @@ class DynamicVar: in conjunction with a decorator. """ - def __init__(self): + def __init__(self) -> None: self._local = threading.local() def _values(self): @@ -238,10 +239,10 @@ class AnnotatingDecorator: lookup: maps functions to their attributes. 
""" - def __init__(self): + def __init__(self) -> None: self.lookup = {} - def __call__(self, value): + def __call__(self, value) -> Callable[[Any], Any]: def decorate(f): self.lookup[f.__name__] = value return f @@ -253,7 +254,7 @@ class ContextWeakrefMixin: __slots__ = ["ctx_weakref"] - def __init__(self, ctx): + def __init__(self, ctx) -> None: self.ctx_weakref = weakref.ref(ctx) @property diff --git a/pytype/vm.py b/pytype/vm.py index baa8d752a..16384c61a 100644 --- a/pytype/vm.py +++ b/pytype/vm.py @@ -18,7 +18,7 @@ import itertools import logging import re -from typing import Any +from typing import TypeVar, Union, Any from pycnite import marshal as pyc_marshal from pytype import block_environment @@ -49,8 +49,10 @@ from pytype.typegraph import cfg from pytype.typegraph import cfg_utils +_T0 = TypeVar("_T0") -log = logging.getLogger(__name__) + +log: logging.Logger = logging.getLogger(__name__) @dataclasses.dataclass(eq=True, frozen=True) @@ -71,7 +73,7 @@ def is_annotate(self): return self.op == self.Op.ANNOTATE -_opcode_counter = metrics.MapCounter("vm_opcode") +_opcode_counter: metrics.MapCounter = metrics.MapCounter("vm_opcode") class _UninitializedBehavior(enum.Enum): @@ -90,7 +92,7 @@ class VirtualMachine: class VirtualMachineRecursionError(Exception): pass - def __init__(self, ctx): + def __init__(self, ctx) -> None: """Construct a TypegraphVirtualMachine.""" self.ctx = ctx # context.Context # The call stack of frames. 
@@ -174,7 +176,9 @@ def _suppress_opcode_tracing(self): self._trace_opcodes = old_trace_opcodes @contextlib.contextmanager - def generate_late_annotations(self, stack): + def generate_late_annotations( + self, stack + ): old_late_annotations_stack = self._late_annotations_stack self._late_annotations_stack = stack try: @@ -182,7 +186,7 @@ def generate_late_annotations(self, stack): finally: self._late_annotations_stack = old_late_annotations_stack - def trace_opcode(self, op, symbol, val): + def trace_opcode(self, op, symbol, val) -> None: """Record trace data for other tools to use.""" if not self._trace_opcodes: return @@ -210,7 +214,7 @@ def remaining_depth(self): assert self._maximum_depth is not None return self._maximum_depth - len(self.frames) - def is_at_maximum_depth(self): + def is_at_maximum_depth(self) -> bool: return len(self.frames) > self._maximum_depth def _is_match_case_op(self, op): @@ -237,7 +241,7 @@ def _is_match_case_op(self, op): ) return is_match or is_cmp_match or is_default_match or is_none_match - def _handle_match_case(self, state, op): + def _handle_match_case(self, state: _T0, op) -> _T0: """Track type narrowing and default cases in a match statement.""" if not self._is_match_case_op(op): return state @@ -307,7 +311,9 @@ def run_instruction( self.frame.current_opcode = None return state - def _run_frame_blocks(self, frame, node, annotated_locals): + def _run_frame_blocks( + self, frame, node, annotated_locals + ) -> tuple[bool, list]: """Runs a frame's code blocks.""" frame.states[frame.f_code.get_first_opcode()] = frame_state.FrameState.init( node, self.ctx @@ -364,7 +370,7 @@ def _run_frame_blocks(self, frame, node, annotated_locals): vm_utils.update_excluded_types(node, self.ctx) return can_return, return_nodes - def run_frame(self, frame, node, annotated_locals=None): + def run_frame(self, frame, node, annotated_locals=None) -> tuple[Any, Any]: """Run a frame (typically belonging to a method).""" self.push_frame(frame) try: @@ 
-390,11 +396,11 @@ def run_frame(self, frame, node, annotated_locals=None): ) return node, frame.return_variable - def push_frame(self, frame): + def push_frame(self, frame) -> None: self.frames.append(frame) self.frame = frame - def pop_frame(self, frame): + def pop_frame(self, frame) -> None: popped_frame = self.frames.pop() assert popped_frame == frame if self.frames: @@ -420,7 +426,7 @@ def make_frame( func=None, first_arg=None, substs=(), - ): + ) -> frame_state.Frame: """Create a new frame object, using the given args, globals and locals.""" if any(code is f.f_code for f in self.frames): log.info("Detected recursion in %s", code.name or code.filename) @@ -455,7 +461,7 @@ def make_frame( substs, ) - def simple_stack(self, opcode=None): + def simple_stack(self, opcode=None) -> tuple[frame_state.SimpleFrame, ...]: """Get a stack of simple frames. Args: @@ -511,7 +517,7 @@ def pop_abstract_exception(self, state): state, _ = state.popn(3) return state - def resume_frame(self, node, frame): + def resume_frame(self, node, frame) -> tuple[Any, Any]: frame.f_back = self.frame log.info("resume_frame: %r", frame) node, val = self.run_frame(frame, node) @@ -534,7 +540,9 @@ def compile_src( self.block_graph = block_graph return code - def run_bytecode(self, node, code, f_globals=None, f_locals=None): + def run_bytecode( + self, node, code, f_globals=None, f_locals=None + ) -> tuple[Any, Any, Any, Any]: """Run the given bytecode.""" if f_globals is not None: assert f_locals @@ -560,7 +568,7 @@ def run_bytecode(self, node, code, f_globals=None, f_locals=None): node, return_var = self.run_frame(frame, node) return node, frame.f_globals, frame.f_locals, return_var - def run_program(self, src, filename, maximum_depth): + def run_program(self, src, filename, maximum_depth) -> tuple[Any, Any]: """Run the code and return the CFG nodes. Args: @@ -588,9 +596,9 @@ def run_program(self, src, filename, maximum_depth): # but there isn't a better way to wire both pieces together. 
self.ctx.errorlog.set_error_filter(director.filter_error) self._director = director - self.ctx.options.set_feature_flags(director.features) + self.ctx.options.set_feature_flags(director.features) # pytype: disable=attribute-error self._branch_tracker = pattern_matching.BranchTracker( - director.matches, self.ctx + director.matches, self.ctx # pytype: disable=attribute-error ) code = process_blocks.merge_annotations( code, self._director.annotations, self._director.param_annotations @@ -627,13 +635,13 @@ def run_program(self, src, filename, maximum_depth): log.info("Final node: <%d>%s", node.id, node.name) return node, f_globals.members - def flatten_late_annotation(self, node, annot, f_globals): + def flatten_late_annotation(self, node, annot, f_globals) -> None: flattened_expr = annot.flatten_expr() if flattened_expr != annot.expr: annot.expr = flattened_expr f_globals.members[flattened_expr] = annot.to_variable(node) - def set_var_name(self, var, name): + def set_var_name(self, var, name) -> None: self._var_names[var.id] = name def get_var_name(self, var): @@ -653,7 +661,7 @@ def get_var_name(self, var): names = {self._var_names.get(s.variable.id) for s in sources} return next(iter(names)) if len(names) == 1 else None - def get_all_named_vars(self): + def get_all_named_vars(self) -> dict[None, None]: # Make a shallow copy of the dict so callers aren't touching internal data. return dict(self._var_names) @@ -685,11 +693,13 @@ def trace_functiondef(self, *args): def trace_classdef(self, *args): return NotImplemented - def call_init(self, node, unused_instance): + def call_init(self, node: _T0, unused_instance) -> _T0: # This dummy implementation is overwritten in tracer_vm.py. return node - def init_class(self, node, cls, container=None, extra_key=None): + def init_class( + self, node, cls, container=None, extra_key=None + ): # This dummy implementation is overwritten in tracer_vm.py. 
del cls, container, extra_key return NotImplemented @@ -729,7 +739,7 @@ def call_function_with_state( state = self._check_test_assert(state, funcv, posargs) return state, ret - def call_with_fake_args(self, node0, funcv): + def call_with_fake_args(self, node0: _T0, funcv) -> tuple[_T0, Any]: """Attempt to call the given function with made-up arguments.""" return node0, self.ctx.new_unsolvable(node0) @@ -813,7 +823,7 @@ def call_function_from_stack_311(self, state, num): state, func, posargs, namedargs, starargs, starstarargs ) - def get_globals_dict(self): + def get_globals_dict(self) -> abstract.LazyConcreteDict: """Get a real python dict of the globals.""" return self.frame.f_globals @@ -857,7 +867,9 @@ def load_from( self.set_var_name(ret, name) return state, ret - def load_local(self, state, name): + def load_local( + self, state, name + ) -> tuple[frame_state.FrameState, cfg.Variable]: """Called when a local is loaded onto the stack. Uses the name to retrieve the value from the current locals(). @@ -886,7 +898,9 @@ def load_local(self, state, name): return self.load_from(state, self.frame.f_locals, name) - def load_global(self, state, name): + def load_global( + self, state, name + ) -> tuple[frame_state.FrameState, cfg.Variable]: # The concrete value of typing.TYPE_CHECKING should be preserved; otherwise, # concrete values are converted to abstract instances of their types, as we # generally can't assume that globals are constant. @@ -905,7 +919,9 @@ def load_special_builtin(self, name): else: return self.ctx.special_builtins.get(name) - def load_builtin(self, state, name): + def load_builtin( + self, state: _T0, name + ) -> tuple[Union[frame_state.FrameState, _T0], Any]: if name == "__undefined__": # For values that don't exist. 
(Unlike None, which is a valid object) return state, self.ctx.convert.empty.to_variable(self.ctx.root_node) @@ -931,7 +947,9 @@ def _load_annotation(self, node, name, store): return ret raise KeyError(name) - def _record_local(self, node, op, name, typ, orig_val=None, final=None): + def _record_local( + self, node, op, name, typ, orig_val=None, final=None + ) -> None: """Record a type annotation on a local variable. This method records three types of local operations: @@ -967,7 +985,7 @@ def _record_local(self, node, op, name, typ, orig_val=None, final=None): def _update_annotations_dict( self, node, op, name, typ, orig_val, annotations_dict, final=None - ): + ) -> None: if name in annotations_dict: annotations_dict[name].update(node, op, typ, orig_val) else: @@ -1137,11 +1155,11 @@ def _var_is_none(self, v: cfg.Variable) -> bool: self._data_is_none(b.data) for b in v.bindings ) - def _delete_item(self, state, obj, arg): + def _delete_item(self, state, obj, arg) -> frame_state.FrameState: state, _ = self._call(state, obj, "__delitem__", (arg,)) return state - def load_attr(self, state, obj, attr): + def load_attr(self, state, obj, attr) -> tuple[Any, Any]: """Try loading an attribute, and report errors.""" node, result, errors = self._retrieve_attr(state.node, obj, attr) self._attribute_error_detection(state, attr, errors) @@ -1149,7 +1167,7 @@ def load_attr(self, state, obj, attr): result = self.ctx.new_unsolvable(node) return state.change_cfg_node(node), result - def _attribute_error_detection(self, state, attr, errors): + def _attribute_error_detection(self, state, attr, errors) -> None: if not self.ctx.options.report_errors: return for error in errors: @@ -1161,7 +1179,7 @@ def _attribute_error_detection(self, state, attr, errors): def _filter_none_and_paste_bindings( self, node, bindings, var, discard_concrete_values=False - ): + ) -> None: """Paste the bindings into var, filtering out false positives on None.""" for b in bindings: if 
self._has_strict_none_origins(b): @@ -1182,7 +1200,7 @@ def _filter_none_and_paste_bindings( # TODO(rechen): Remove once --strict-none-binding is fully enabled. var.AddBinding(self.ctx.convert.unsolvable, [b], node) - def _has_strict_none_origins(self, binding): + def _has_strict_none_origins(self, binding) -> bool: """Whether the binding has any possible origins, with None filtering. Determines whether the binding has any possibly visible origins at the @@ -1221,7 +1239,7 @@ def _has_strict_none_origins(self, binding): has_any_none_origin = True return not has_any_none_origin - def load_attr_noerror(self, state, obj, attr): + def load_attr_noerror(self, state, obj, attr) -> tuple[Any, Any]: """Try loading an attribute, ignore errors.""" node, result, _ = self._retrieve_attr(state.node, obj, attr) return state.change_cfg_node(node), result @@ -1251,7 +1269,7 @@ def store_attr( else: return state - def del_attr(self, state, obj, attr): + def del_attr(self, state: _T0, obj, attr) -> _T0: """Delete an attribute.""" log.info( "Attribute removal does not do anything in the abstract interpreter" @@ -1402,7 +1420,7 @@ def unary_operator(self, state, name): state = state.push(result) return state - def _is_classmethod_cls_arg(self, var): + def _is_classmethod_cls_arg(self, var) -> bool: """True if var is the first arg of a class method in the current frame.""" if not (self.frame.func and self.frame.first_arg): return False @@ -1441,7 +1459,7 @@ def _get_aiter(self, state, obj): else: return state, self.ctx.new_unsolvable(state.node) - def _get_iter(self, state, seq, report_errors=True): + def _get_iter(self, state, seq, report_errors=True) -> tuple[Any, Any]: """Get an iterator from a sequence.""" # TODO(b/201603421): We should iterate through seq's bindings, in order to # fetch the attribute on the sequence's class, but two problems prevent us @@ -1471,7 +1489,7 @@ def _get_iter(self, state, seq, report_errors=True): self.ctx.errorlog.attribute_error(self.frames, m, 
"__iter__") return state, itr - def byte_NOP(self, state, op): + def byte_NOP(self, state: _T0, op) -> _T0: return state def byte_UNARY_NOT(self, state, op): @@ -1821,7 +1839,7 @@ def byte_LOAD_CLASSDEREF(self, state, op): except KeyError: return vm_utils.load_closure_cell(state, op, False, self.ctx) - def _cmp_rel(self, state, op_name, x, y): + def _cmp_rel(self, state, op_name, x, y) -> tuple[Any, Any]: """Implementation of relational operators CMP_(LT|LE|EQ|NE|GE|GT). Args: @@ -1907,7 +1925,7 @@ def _coerce_to_bool(self, var, true_val=True): bool_var.PasteBindingWithNewData(b, self.ctx.convert.bool_values[const]) return bool_var - def _cmp_in(self, state, item, seq, true_val=True): + def _cmp_in(self, state, item, seq, true_val=True) -> tuple[Any, Any]: """Implementation of CMP_IN/CMP_NOT_IN.""" state, has_contains = self.load_attr_noerror(state, seq, "__contains__") if has_contains: @@ -1929,11 +1947,11 @@ def _cmp_in(self, state, item, seq, true_val=True): ret = self.ctx.convert.build_bool(state.node) return state, ret - def _cmp_is_always_supported(self, op_arg): + def _cmp_is_always_supported(self, op_arg) -> bool: """Checks if the comparison should always succeed.""" return op_arg in slots.CMP_ALWAYS_SUPPORTED - def _instantiate_exception(self, node, exc_type): + def _instantiate_exception(self, node, exc_type) -> tuple[Any, list]: """Instantiate an exception type. 
Args: @@ -2097,7 +2115,9 @@ def byte_LOAD_ATTR(self, state, op): self.trace_opcode(op, name, (obj, val)) return state - def _get_type_of_attr_to_store(self, node, op, obj, name): + def _get_type_of_attr_to_store( + self, node, op, obj, name + ) -> tuple[Any, Any, bool]: """Grabs the __annotations__ dict, if any, with the attribute type.""" check_type = True annotations_dict = None @@ -2169,7 +2189,7 @@ def _get_type_of_attr_to_store(self, node, op, obj, name): # pylint: enable=unsupported-assignment-operation,unsupported-membership-test return node, annotations_dict, check_type - def byte_STORE_ATTR(self, state, op): + def byte_STORE_ATTR(self, state, op) -> frame_state.FrameState: """Store an attribute.""" name = op.argval state, (val, obj) = state.popn(2) @@ -2191,11 +2211,11 @@ def byte_DELETE_ATTR(self, state, op): state, obj = state.pop() return self.del_attr(state, obj, name) - def store_subscr(self, state, obj, key, val): + def store_subscr(self, state, obj, key, val) -> frame_state.FrameState: state, _ = self._call(state, obj, "__setitem__", (key, val)) return state - def _record_annotation_dict_store(self, state, obj, subscr, val, op): + def _record_annotation_dict_store(self, state, obj, subscr, val, op) -> None: """Record a store_subscr to an __annotations__ dict.""" try: name = abstract_utils.get_atomic_python_constant(subscr, str) @@ -2377,7 +2397,7 @@ def byte_BUILD_SLICE(self, state, op): else: # pragma: no cover raise VirtualMachineError(f"Strange BUILD_SLICE count: {op.arg!r}") - def byte_LIST_APPEND(self, state, op): + def byte_LIST_APPEND(self, state, op) -> frame_state.FrameState: # Used by the compiler e.g. for [x for x in ...] count = op.arg state, val = state.pop() @@ -2437,7 +2457,7 @@ def byte_LIST_EXTEND(self, state, op): ) return state - def byte_SET_ADD(self, state, op): + def byte_SET_ADD(self, state, op) -> frame_state.FrameState: # Used by the compiler e.g. 
for {x for x in ...} count = op.arg state, val = state.pop() @@ -2445,13 +2465,13 @@ def byte_SET_ADD(self, state, op): state, _ = self._call(state, the_set, "add", (val,)) return state - def byte_SET_UPDATE(self, state, op): + def byte_SET_UPDATE(self, state, op) -> frame_state.FrameState: state, update = state.pop() target = state.peek(op.arg) state, _ = self._call(state, target, "update", (update,)) return state - def byte_MAP_ADD(self, state, op): + def byte_MAP_ADD(self, state, op) -> frame_state.FrameState: """Implements the MAP_ADD opcode.""" # Used by the compiler e.g. for {x, y for x, y in ...} count = op.arg @@ -2516,11 +2536,11 @@ def byte_POP_JUMP_IF_FALSE(self, state, op): state, op, self.ctx, jump_if_val=False, pop=vm_utils.PopBehavior.ALWAYS ) - def byte_JUMP_FORWARD(self, state, op): + def byte_JUMP_FORWARD(self, state: _T0, op) -> _T0: self.store_jump(op.target, state.forward_cfg_node("JumpForward")) return state - def byte_JUMP_ABSOLUTE(self, state, op): + def byte_JUMP_ABSOLUTE(self, state: _T0, op) -> _T0: self.store_jump(op.target, state.forward_cfg_node("JumpAbsolute")) return state @@ -2563,7 +2583,7 @@ def byte_GET_ITER(self, state, op): # Push the iterator onto the stack and return. 
return state.push(itr) - def store_jump(self, target, state): + def store_jump(self, target, state) -> None: """Stores a jump to the target opcode.""" assert target assert self.frame is not None @@ -2601,14 +2621,14 @@ def _revert_state_to(self, state, name): state = state.pop_and_discard() return state - def byte_BREAK_LOOP(self, state, op): + def byte_BREAK_LOOP(self, state: _T0, op) -> _T0: new_state, block = self._revert_state_to(state, "loop").pop_block() while block.level < len(new_state.data_stack): new_state = new_state.pop_and_discard() self.store_jump(op.block_target, new_state) return state - def byte_CONTINUE_LOOP(self, state, op): + def byte_CONTINUE_LOOP(self, state: _T0, op) -> _T0: new_state = self._revert_state_to(state, "loop") self.store_jump(op.target, new_state) return state @@ -2625,7 +2645,7 @@ def _setup_except(self, state, op): self.store_jump(op.target, jump_state) return vm_utils.push_block(state, "setup-except", index=op.index) - def is_setup_except(self, op): + def is_setup_except(self, op) -> bool: """Check whether op is setting up an except block.""" if isinstance(op, opcodes.SETUP_FINALLY): for i, block in enumerate(self.frame.f_code.order): @@ -2657,7 +2677,7 @@ def byte_SETUP_FINALLY(self, state, op): def byte_BEGIN_FINALLY(self, state, op): return state.push(self.ctx.convert.build_none(state.node)) - def byte_CALL_FINALLY(self, state, op): + def byte_CALL_FINALLY(self, state: _T0, op) -> _T0: return state def byte_END_ASYNC_FOR(self, state, op): @@ -2753,7 +2773,7 @@ def byte_WITH_CLEANUP_FINISH(self, state, op): state = state.push(self.ctx.convert.build_none(state.node)) return state - def _convert_kw_defaults(self, values): + def _convert_kw_defaults(self, values) -> dict: kw_defaults = {} for i in range(0, len(values), 2): key_var, value = values[i : i + 2] @@ -2761,7 +2781,9 @@ def _convert_kw_defaults(self, values): kw_defaults[key] = value return kw_defaults - def _get_extra_closure_args(self, state, arg): + def 
_get_extra_closure_args( + self, state, arg + ) -> tuple[Any, Any, Any, Any, None]: """Get closure annotations and defaults from the stack.""" num_pos_defaults = arg & 0xFF num_kw_defaults = (arg >> 8) & 0xFF @@ -2775,7 +2797,9 @@ def _get_extra_closure_args(self, state, arg): ) return state, pos_defaults, kw_defaults, annot, free_vars - def _get_extra_function_args(self, state, arg): + def _get_extra_function_args( + self, state, arg + ) -> tuple[Any, Any, Any, Any, Any]: """Get function annotations and defaults from the stack.""" free_vars = None pos_defaults = () @@ -3008,10 +3032,10 @@ def byte_END_FINALLY(self, state, op): # no handler matched, hence Python re-raises the exception. return state.set_why("reraise") - def _check_return(self, node, actual, formal): + def _check_return(self, node, actual, formal) -> bool: return False # overwritten in tracer_vm.py - def _set_frame_return(self, node, frame, var): + def _set_frame_return(self, node, frame, var) -> None: if frame.allowed_returns is not None: retvar = self.init_class(node, frame.allowed_returns) else: @@ -3075,7 +3099,7 @@ def byte_SETUP_ANNOTATIONS(self, state, op): ).to_variable(state.node) return self.store_local(state, "__annotations__", annotations) - def _record_annotation(self, node, op, name, typ): + def _record_annotation(self, node, op, name, typ) -> None: # Annotations in self._director are handled by _apply_annotation. if self.current_line not in self._director.annotations: self._record_local(node, op, name, typ) @@ -3298,7 +3322,7 @@ def byte_YIELD_FROM(self, state, op): ret_var = self._get_generator_return(state.node, generator) return state.push(ret_var) - def _load_method(self, state, self_obj, name): + def _load_method(self, state, self_obj, name) -> tuple[Any, Any]: """Loads and pushes a method on the stack. 
Args: @@ -3520,7 +3544,7 @@ def byte_GEN_START(self, state, op): del op return state.pop_and_discard() - def byte_CACHE(self, state, op): + def byte_CACHE(self, state: _T0, op) -> _T0: # No stack or type effects del op return state @@ -3580,15 +3604,15 @@ def byte_BEFORE_WITH(self, state, op): state, ctxmgr_obj = self._call(state, ctxmgr, "__enter__", ()) return state.push(ctxmgr_obj) - def byte_RETURN_GENERATOR(self, state, op): + def byte_RETURN_GENERATOR(self, state: _T0, op) -> _T0: del op return state - def byte_ASYNC_GEN_WRAP(self, state, op): + def byte_ASYNC_GEN_WRAP(self, state: _T0, op) -> _T0: del op return state - def byte_PREP_RERAISE_STAR(self, state, op): + def byte_PREP_RERAISE_STAR(self, state: _T0, op) -> _T0: del op return state @@ -3684,28 +3708,28 @@ def byte_POP_JUMP_FORWARD_IF_NONE(self, state, op): state, op, self.ctx, jump_if_val=None, pop=vm_utils.PopBehavior.ALWAYS ) - def byte_JUMP_BACKWARD_NO_INTERRUPT(self, state, op): + def byte_JUMP_BACKWARD_NO_INTERRUPT(self, state: _T0, op) -> _T0: self.store_jump(op.target, state.forward_cfg_node("JumpBackward")) return state - def byte_MAKE_CELL(self, state, op): + def byte_MAKE_CELL(self, state: _T0, op) -> _T0: del op return state - def byte_JUMP_BACKWARD(self, state, op): + def byte_JUMP_BACKWARD(self, state: _T0, op) -> _T0: self.store_jump(op.target, state.forward_cfg_node("JumpBackward")) return state - def byte_COPY_FREE_VARS(self, state, op): + def byte_COPY_FREE_VARS(self, state: _T0, op) -> _T0: self.frame.copy_free_vars(op.arg) return state - def byte_RESUME(self, state, op): + def byte_RESUME(self, state: _T0, op) -> _T0: # No stack or type effects del op return state - def byte_PRECALL(self, state, op): + def byte_PRECALL(self, state: _T0, op) -> _T0: # No stack or type effects del op return state @@ -3713,7 +3737,7 @@ def byte_PRECALL(self, state, op): def byte_CALL(self, state, op): return self.call_function_from_stack_311(state, op.arg) - def byte_KW_NAMES(self, state, op): + def 
byte_KW_NAMES(self, state: _T0, op) -> _T0: # Stores a list of kw names to be retrieved by CALL self._kw_names = op.argval return state @@ -3738,11 +3762,11 @@ def byte_POP_JUMP_BACKWARD_IF_FALSE(self, state, op): def byte_POP_JUMP_BACKWARD_IF_TRUE(self, state, op): return self.byte_POP_JUMP_IF_TRUE(state, op) - def byte_INTERPRETER_EXIT(self, state, op): + def byte_INTERPRETER_EXIT(self, state: _T0, op) -> _T0: del op return state - def byte_END_FOR(self, state, op): + def byte_END_FOR(self, state: _T0, op) -> _T0: # No-op in pytype. See comment in `byte_FOR_ITER` for details. return state @@ -3751,7 +3775,7 @@ def byte_END_SEND(self, state, op): state, top = state.pop() return state.set_top(top) - def byte_RESERVED(self, state, op): + def byte_RESERVED(self, state: _T0, op) -> _T0: del op return state @@ -3768,7 +3792,7 @@ def byte_STORE_SLICE(self, state, op): subscr = self.ctx.convert.build_slice(state.node, start, end) return self.store_subscr(state, obj, subscr, val) - def byte_CLEANUP_THROW(self, state, op): + def byte_CLEANUP_THROW(self, state: _T0, op) -> _T0: # In 3.12 the only use of CLEANUP_THROW is for exception handling in # generators. Pytype elides the opcode in opcodes::_make_opcode_list. 
del op @@ -3777,7 +3801,7 @@ def byte_CLEANUP_THROW(self, state, op): def byte_LOAD_LOCALS(self, state, op): return state.push(self.frame.f_locals.to_variable(state.node)) - def byte_LOAD_FROM_DICT_OR_GLOBALS(self, state, op): + def byte_LOAD_FROM_DICT_OR_GLOBALS(self, state: _T0, op) -> _T0: # TODO: b/350910471 - Implement to support PEP 695 return state @@ -3829,10 +3853,10 @@ def byte_CALL_INTRINSIC_2(self, state, op): raise VirtualMachineError(f"Unknown intrinsic function: {op.argval}") return intrinsic_fn(state) - def byte_INTRINSIC_1_INVALID(self, state): + def byte_INTRINSIC_1_INVALID(self, state: _T0) -> _T0: return state - def byte_INTRINSIC_PRINT(self, state): + def byte_INTRINSIC_PRINT(self, state: _T0) -> _T0: # Only used in the interactive interpreter, not in modules. return state @@ -3840,11 +3864,11 @@ def byte_INTRINSIC_IMPORT_STAR(self, state): state = self._import_star(state) return self._push_null(state) - def byte_INTRINSIC_STOPITERATION_ERROR(self, state): + def byte_INTRINSIC_STOPITERATION_ERROR(self, state: _T0) -> _T0: # Changes StopIteration or StopAsyncIteration to a RuntimeError. 
return state - def byte_INTRINSIC_ASYNC_GEN_WRAP(self, state): + def byte_INTRINSIC_ASYNC_GEN_WRAP(self, state: _T0) -> _T0: return state def byte_INTRINSIC_UNARY_POSITIVE(self, state): @@ -3853,41 +3877,41 @@ def byte_INTRINSIC_UNARY_POSITIVE(self, state): def byte_INTRINSIC_LIST_TO_TUPLE(self, state): return self._list_to_tuple(state) - def byte_INTRINSIC_TYPEVAR(self, state): + def byte_INTRINSIC_TYPEVAR(self, state: _T0) -> _T0: # TODO: b/350910471 - Implement to support PEP 695 return state - def byte_INTRINSIC_PARAMSPEC(self, state): + def byte_INTRINSIC_PARAMSPEC(self, state: _T0) -> _T0: # TODO: b/350910471 - Implement to support PEP 695 return state - def byte_INTRINSIC_TYPEVARTUPLE(self, state): + def byte_INTRINSIC_TYPEVARTUPLE(self, state: _T0) -> _T0: # TODO: b/350910471 - Implement to support PEP 695 return state - def byte_INTRINSIC_SUBSCRIPT_GENERIC(self, state): + def byte_INTRINSIC_SUBSCRIPT_GENERIC(self, state: _T0) -> _T0: # TODO: b/350910471 - Implement to support PEP 695 return state - def byte_INTRINSIC_TYPEALIAS(self, state): + def byte_INTRINSIC_TYPEALIAS(self, state: _T0) -> _T0: # TODO: b/350910471 - Implement to support PEP 695 return state - def byte_INTRINSIC_2_INVALID(self, state): + def byte_INTRINSIC_2_INVALID(self, state: _T0) -> _T0: return state - def byte_INTRINSIC_PREP_RERAISE_STAR(self, state): + def byte_INTRINSIC_PREP_RERAISE_STAR(self, state: _T0) -> _T0: return state - def byte_INTRINSIC_TYPEVAR_WITH_BOUND(self, state): + def byte_INTRINSIC_TYPEVAR_WITH_BOUND(self, state: _T0) -> _T0: # TODO: b/350910471 - Implement to support PEP 695 return state - def byte_INTRINSIC_TYPEVAR_WITH_CONSTRAINTS(self, state): + def byte_INTRINSIC_TYPEVAR_WITH_CONSTRAINTS(self, state: _T0) -> _T0: # TODO: b/350910471 - Implement to support PEP 695 return state - def byte_INTRINSIC_SET_FUNCTION_TYPE_PARAMS(self, state): + def byte_INTRINSIC_SET_FUNCTION_TYPE_PARAMS(self, state: _T0) -> _T0: # TODO: b/350910471 - Implement to support PEP 695 
return state diff --git a/pytype/vm_utils.py b/pytype/vm_utils.py index 87121d741..7142211a1 100644 --- a/pytype/vm_utils.py +++ b/pytype/vm_utils.py @@ -9,6 +9,7 @@ import logging import re import reprlib +from typing import Any from pytype import overriding_checks from pytype import state as frame_state @@ -27,20 +28,25 @@ from pytype.pytd import slots from pytype.typegraph import cfg -log = logging.getLogger(__name__) + +log: logging.Logger = logging.getLogger(__name__) # Create a repr that won't overflow. _TRUNCATE = 120 _TRUNCATE_STR = 72 -_repr_obj = reprlib.Repr() +_repr_obj: reprlib.Repr = reprlib.Repr() _repr_obj.maxother = _TRUNCATE _repr_obj.maxstring = _TRUNCATE_STR repper = _repr_obj.repr -_FUNCTION_TYPE_COMMENT_RE = re.compile(r"^\((.*)\)\s*->\s*(\S.*?)\s*$") +_FUNCTION_TYPE_COMMENT_RE: re.Pattern = re.compile( + r"^\((.*)\)\s*->\s*(\S.*?)\s*$" +) # Classes supporting pattern matching without an explicit __match_args__ -_BUILTIN_MATCHERS = ( +_BUILTIN_MATCHERS: tuple[ + str, str, str, str, str, str, str, str, str, str, str +] = ( "bool", "bytearray", "bytes", @@ -69,14 +75,14 @@ class _Block: level: int op_index: int - def __repr__(self): + def __repr__(self) -> str: return f"Block({self.type}: {self.op_index} level={self.level})" class FindIgnoredTypeComments(pyc.CodeVisitor): """A visitor that finds type comments that will be ignored.""" - def __init__(self, type_comments): + def __init__(self, type_comments) -> None: super().__init__() self._type_comments = type_comments # Lines will be removed from this set during visiting. 
Any lines that remain @@ -97,7 +103,7 @@ def visit_code(self, code): self._ignored_type_lines.discard(line) return code - def ignored_lines(self): + def ignored_lines(self) -> set: """Returns a set of lines that contain ignored type comments.""" return self._ignored_type_lines @@ -107,9 +113,9 @@ class FinallyStateTracker: # Used in vm.run_frame() - RETURN_STATES = ("return", "exception") + RETURN_STATES: tuple[str, str] = ("return", "exception") - def __init__(self): + def __init__(self) -> None: self.stack = [] def process(self, op, state, ctx) -> str | None: @@ -132,7 +138,7 @@ def check_early_exit(self, state) -> bool: and state.why in self.RETURN_STATES ) - def __repr__(self): + def __repr__(self) -> str: return repr(self.stack) @@ -146,12 +152,12 @@ def to_error_message(self) -> str: class _NameInInnerClassErrorDetails(_NameErrorDetails): - def __init__(self, attr, class_name): + def __init__(self, attr, class_name) -> None: super().__init__() self._attr = attr self._class_name = class_name - def to_error_message(self): + def to_error_message(self) -> str: return ( f"Cannot reference {self._attr!r} from class {self._class_name!r} " "before the class is fully defined" @@ -161,13 +167,13 @@ def to_error_message(self): class _NameInOuterClassErrorDetails(_NameErrorDetails): """Name error details for a name defined in an outer class.""" - def __init__(self, attr, prefix, class_name): + def __init__(self, attr, prefix, class_name) -> None: super().__init__() self._attr = attr self._prefix = prefix self._class_name = class_name - def to_error_message(self): + def to_error_message(self) -> str: full_attr_name = f"{self._class_name}.{self._attr}" if self._prefix: full_class_name = f"{self._prefix}.{self._class_name}" @@ -182,13 +188,13 @@ def to_error_message(self): class _NameInOuterFunctionErrorDetails(_NameErrorDetails): """Name error details for a name defined in an outer function.""" - def __init__(self, attr, outer_scope, inner_scope): + def __init__(self, 
attr, outer_scope, inner_scope) -> None: super().__init__() self._attr = attr self._outer_scope = outer_scope self._inner_scope = inner_scope - def to_error_message(self): + def to_error_message(self) -> str: keyword = "global" if "global" in self._outer_scope else "nonlocal" return ( f"Add `{keyword} {self._attr}` in {self._inner_scope} to reference " @@ -372,7 +378,7 @@ def get_name_error_details(state, name: str, ctx) -> _NameErrorDetails | None: return None -def log_opcode(op, state, frame, stack_size): +def log_opcode(op, state, frame, stack_size) -> None: """Write a multi-line log message, including backtrace and stack.""" if not log.isEnabledFor(logging.INFO): return @@ -424,7 +430,7 @@ def _process_base_class(node, base, ctx): return base -def _filter_out_metaclasses(bases, ctx): +def _filter_out_metaclasses(bases, ctx) -> tuple[Any, list]: """Process the temporary classes created by six.with_metaclass. six.with_metaclass constructs an anonymous class holding a metaclass and a @@ -454,7 +460,7 @@ def _filter_out_metaclasses(bases, ctx): return meta, non_meta -def _expand_generic_protocols(node, bases, ctx): +def _expand_generic_protocols(node, bases, ctx) -> list: """Expand Protocol[T, ...] 
to Protocol, Generic[T, ...].""" expanded_bases = [] for base in bases: @@ -484,7 +490,7 @@ def _expand_generic_protocols(node, bases, ctx): return expanded_bases -def _check_final_members(cls, class_dict, ctx): +def _check_final_members(cls, class_dict, ctx) -> None: """Check if the new class overrides a final attribute or method.""" methods = set(class_dict) for base in cls.mro[1:]: @@ -620,7 +626,7 @@ def make_class(node, props, ctx): return node, var -def _check_defaults(node, method, ctx): +def _check_defaults(node, method, ctx) -> None: """Check parameter defaults against annotations.""" if not method.signature.has_param_annotations: return @@ -631,9 +637,9 @@ def _check_defaults(node, method, ctx): raise AssertionError( "Unexpected argument matching error: %s" % e.__class__.__name__ ) from e - for e, arg_name, value in errors: + for e, arg_name, value in errors: # pytype: disable=attribute-error bad_param = e.bad_call.bad_param - expected_type = bad_param.typ + expected_type = bad_param.typ # pytype: disable=wrong-arg-types if value == ctx.convert.ellipsis: # `...` should be a valid default parameter value for overloads. # Unfortunately, the is_overload attribute is not yet set when @@ -697,7 +703,7 @@ def make_function( return var -def update_excluded_types(node, ctx): +def update_excluded_types(node, ctx) -> None: """Update the excluded_types attribute of functions in the current frame.""" if not ctx.vm.frame.func: return @@ -737,7 +743,7 @@ def _base(cls): return cls -def _overrides(subcls, supercls, attr): +def _overrides(subcls, supercls, attr) -> bool: """Check whether subcls_var overrides or newly defines the given attribute. 
Args: @@ -766,7 +772,9 @@ def _overrides(subcls, supercls, attr): return False -def _call_binop_on_bindings(node, name, xval, yval, ctx): +def _call_binop_on_bindings( + node, name, xval, yval, ctx +) -> tuple[Any, Any]: """Call a binary operator on two cfg.Binding objects.""" rname = slots.REVERSE_NAME_MAPPING.get(name) if rname and isinstance(xval.data, abstract.AMBIGUOUS_OR_EMPTY): @@ -853,7 +861,9 @@ def _maybe_union(node, x, y, ctx): return abstract.Union(opts, ctx).to_variable(node) -def call_binary_operator(state, name, x, y, report_errors, ctx): +def call_binary_operator( + state, name, x, y, report_errors, ctx +) -> tuple[Any, Any]: """Map a binary operator to "magic methods" (__add__ etc.).""" results = [] log.debug("Calling binary operator %s", name) @@ -907,7 +917,7 @@ def call_binary_operator(state, name, x, y, report_errors, ctx): return state, result -def call_inplace_operator(state, iname, x, y, ctx): +def call_inplace_operator(state, iname, x, y, ctx) -> tuple[Any, Any]: """Try to call a method like __iadd__, possibly fall back to __add__.""" state, attr = ctx.vm.load_attr_noerror(state, x, iname) if attr is None: @@ -930,7 +940,7 @@ def call_inplace_operator(state, iname, x, y, ctx): return state, ret -def check_for_deleted(state, name, var, ctx): +def check_for_deleted(state, name, var, ctx) -> None: for x in var.Data(state.node): if isinstance(x, abstract.Deleted): # Referencing a deleted variable @@ -1043,7 +1053,7 @@ def jump_if(state, op, ctx, *, jump_if_val, pop=PopBehavior.NONE): return state.forward_cfg_node("NoJump", normal.binding if normal else None) -def process_function_type_comment(node, op, func, ctx): +def process_function_type_comment(node, op, func, ctx) -> None: """Modifies annotations from a function type comment. Checks if a type comment is present for the function. 
If so, the type @@ -1295,7 +1305,7 @@ def copy_dict_without_keys( return ret.to_variable(node) -def unpack_iterable(node, var, ctx): +def unpack_iterable(node, var, ctx) -> list: """Unpack an iterable.""" elements = [] try: @@ -1330,7 +1340,7 @@ def unpack_iterable(node, var, ctx): return elements -def pop_and_unpack_list(state, count, ctx): +def pop_and_unpack_list(state, count, ctx) -> tuple[Any, list]: """Pop count iterables off the stack and concatenate.""" state, iterables = state.popn(count) elements = [] @@ -1339,7 +1349,7 @@ def pop_and_unpack_list(state, count, ctx): return state, elements -def merge_indefinite_iterables(node, target, iterables_to_merge): +def merge_indefinite_iterables(node, target, iterables_to_merge) -> None: for var in iterables_to_merge: if abstract_utils.is_var_splat(var): for val in abstract_utils.unwrap_splat(var).data: @@ -1440,7 +1450,7 @@ def _binding_to_coroutine(state, b, bad_bindings, ret, top, ctx): return state -def to_coroutine(state, obj, top, ctx): +def to_coroutine(state, obj, top, ctx) -> tuple[Any, Any]: """Convert any awaitables and generators in obj to coroutines. Implements the GET_AWAITABLE opcode, which returns obj unchanged if it is a