diff --git a/mypy/semanal.py b/mypy/semanal.py index 4128369ace5d..48be004daf76 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -224,9 +224,9 @@ from mypy.tvar_scope import TypeVarLikeScope from mypy.typeanal import ( SELF_TYPE_NAMES, + FindTypeVarVisitor, TypeAnalyser, TypeVarLikeList, - TypeVarLikeQuery, analyze_type_alias, check_for_explicit_any, detect_diverging_alias, @@ -2034,6 +2034,11 @@ def analyze_unbound_tvar_impl( assert isinstance(sym.node, TypeVarExpr) return t.name, sym.node + def find_type_var_likes(self, t: Type) -> TypeVarLikeList: + visitor = FindTypeVarVisitor(self, self.tvar_scope) + t.accept(visitor) + return visitor.type_var_likes + def get_all_bases_tvars( self, base_type_exprs: list[Expression], removed: list[int] ) -> TypeVarLikeList: @@ -2046,7 +2051,7 @@ def get_all_bases_tvars( except TypeTranslationError: # This error will be caught later. continue - base_tvars = base.accept(TypeVarLikeQuery(self, self.tvar_scope)) + base_tvars = self.find_type_var_likes(base) tvars.extend(base_tvars) return remove_dups(tvars) @@ -2064,7 +2069,7 @@ def get_and_bind_all_tvars(self, type_exprs: list[Expression]) -> list[TypeVarLi except TypeTranslationError: # This error will be caught later. continue - base_tvars = base.accept(TypeVarLikeQuery(self, self.tvar_scope)) + base_tvars = self.find_type_var_likes(base) tvars.extend(base_tvars) tvars = remove_dups(tvars) # Variables are defined in order of textual appearance. tvar_defs = [] @@ -3490,7 +3495,7 @@ def analyze_alias( ) return None, [], set(), [], False - found_type_vars = typ.accept(TypeVarLikeQuery(self, self.tvar_scope)) + found_type_vars = self.find_type_var_likes(typ) tvar_defs: list[TypeVarLikeType] = [] namespace = self.qualified_name(name) with self.tvar_scope_frame(self.tvar_scope.class_frame(namespace)): diff --git a/mypy/typeanal.py b/mypy/typeanal.py index d238a452e7a9..4d916315bddd 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -1570,32 +1570,32 @@ def tvar_scope_frame(self) -> Iterator[None]: yield self.tvar_scope = old_scope - def find_type_var_likes(self, t: Type, include_callables: bool = True) -> TypeVarLikeList: - return t.accept( - TypeVarLikeQuery(self.api, self.tvar_scope, include_callables=include_callables) - ) - - def infer_type_variables(self, type: CallableType) -> list[tuple[str, TypeVarLikeExpr]]: - """Return list of unique type variables referred to in a callable.""" - names: list[str] = [] - tvars: list[TypeVarLikeExpr] = [] + def find_type_var_likes(self, t: Type) -> TypeVarLikeList: + visitor = FindTypeVarVisitor(self.api, self.tvar_scope) + t.accept(visitor) + return visitor.type_var_likes + + def infer_type_variables( + self, type: CallableType + ) -> tuple[list[tuple[str, TypeVarLikeExpr]], bool]: + """Infer type variables from a callable. + + Return tuple with these items: + - list of unique type variables referred to in a callable + - whether there is a reference to the Self type + """ + visitor = FindTypeVarVisitor(self.api, self.tvar_scope) for arg in type.arg_types: - for name, tvar_expr in self.find_type_var_likes(arg): - if name not in names: - names.append(name) - tvars.append(tvar_expr) + arg.accept(visitor) + # When finding type variables in the return type of a function, don't # look inside Callable types. Type variables only appearing in # functions in the return type belong to those functions, not the # function we're currently analyzing. - for name, tvar_expr in self.find_type_var_likes(type.ret_type, include_callables=False): - if name not in names: - names.append(name) - tvars.append(tvar_expr) + visitor.include_callables = False + type.ret_type.accept(visitor) - if not names: - return [] # Fast path - return list(zip(names, tvars)) + return visitor.type_var_likes, visitor.has_self_type def bind_function_type_variables( self, fun_type: CallableType, defn: Context @@ -1615,10 +1615,7 @@ def bind_function_type_variables( binding = self.tvar_scope.bind_new(var.name, var_expr) defs.append(binding) return defs, has_self_type - typevars = self.infer_type_variables(fun_type) - has_self_type = find_self_type( - fun_type, lambda name: self.api.lookup_qualified(name, defn, suppress_errors=True) - ) + typevars, has_self_type = self.infer_type_variables(fun_type) # Do not define a new type variable if already defined in scope. typevars = [ (name, tvar) for name, tvar in typevars if not self.is_defined_type_var(name, defn) @@ -2062,67 +2059,6 @@ def flatten_tvars(lists: list[list[T]]) -> list[T]: return result -class TypeVarLikeQuery(TypeQuery[TypeVarLikeList]): - """Find TypeVar and ParamSpec references in an unbound type.""" - - def __init__( - self, - api: SemanticAnalyzerCoreInterface, - scope: TypeVarLikeScope, - *, - include_callables: bool = True, - ) -> None: - super().__init__(flatten_tvars) - self.api = api - self.scope = scope - self.include_callables = include_callables - # Only include type variables in type aliases args. This would be anyway - # that case if we expand (as target variables would be overridden with args) - # and it may cause infinite recursion on invalid (diverging) recursive aliases. - self.skip_alias_target = True - - def _seems_like_callable(self, type: UnboundType) -> bool: - if not type.args: - return False - return isinstance(type.args[0], (EllipsisType, TypeList, ParamSpecType)) - - def visit_unbound_type(self, t: UnboundType) -> TypeVarLikeList: - name = t.name - node = None - # Special case P.args and P.kwargs for ParamSpecs only. - if name.endswith("args"): - if name.endswith(".args") or name.endswith(".kwargs"): - base = ".".join(name.split(".")[:-1]) - n = self.api.lookup_qualified(base, t) - if n is not None and isinstance(n.node, ParamSpecExpr): - node = n - name = base - if node is None: - node = self.api.lookup_qualified(name, t) - if ( - node - and isinstance(node.node, TypeVarLikeExpr) - and self.scope.get_binding(node) is None - ): - assert isinstance(node.node, TypeVarLikeExpr) - return [(name, node.node)] - elif not self.include_callables and self._seems_like_callable(t): - return [] - elif node and node.fullname in LITERAL_TYPE_NAMES: - return [] - elif node and node.fullname in ANNOTATED_TYPE_NAMES and t.args: - # Don't query the second argument to Annotated for TypeVars - return self.query_types([t.args[0]]) - else: - return super().visit_unbound_type(t) - - def visit_callable_type(self, t: CallableType) -> TypeVarLikeList: - if self.include_callables: - return super().visit_callable_type(t) - else: - return [] - - class DivergingAliasDetector(TrivialSyntheticTypeTranslator): """See docstring of detect_diverging_alias() for details.""" @@ -2359,3 +2295,149 @@ def unknown_unpack(t: Type) -> bool: if isinstance(unpacked, AnyType) and unpacked.type_of_any == TypeOfAny.special_form: return True return False + + +class FindTypeVarVisitor(SyntheticTypeVisitor[None]): + """Type visitor that looks for type variable types and self types.""" + + def __init__(self, api: SemanticAnalyzerCoreInterface, scope: TypeVarLikeScope) -> None: + self.api = api + self.scope = scope + self.type_var_likes: list[tuple[str, TypeVarLikeExpr]] = [] + self.has_self_type = False + self.seen_aliases: set[TypeAliasType] | None = None + self.include_callables = True + + def _seems_like_callable(self, type: UnboundType) -> bool: + if not type.args: + return False + return isinstance(type.args[0], (EllipsisType, TypeList, ParamSpecType)) + + def visit_unbound_type(self, t: UnboundType) -> None: + name = t.name + node = None + + # Special case P.args and P.kwargs for ParamSpecs only. + if name.endswith("args"): + if name.endswith(".args") or name.endswith(".kwargs"): + base = ".".join(name.split(".")[:-1]) + n = self.api.lookup_qualified(base, t) + if n is not None and isinstance(n.node, ParamSpecExpr): + node = n + name = base + if node is None: + node = self.api.lookup_qualified(name, t) + if node and node.fullname in SELF_TYPE_NAMES: + self.has_self_type = True + if ( + node + and isinstance(node.node, TypeVarLikeExpr) + and self.scope.get_binding(node) is None + ): + if (name, node.node) not in self.type_var_likes: + self.type_var_likes.append((name, node.node)) + elif not self.include_callables and self._seems_like_callable(t): + if find_self_type( + t, lambda name: self.api.lookup_qualified(name, t, suppress_errors=True) + ): + self.has_self_type = True + return + elif node and node.fullname in LITERAL_TYPE_NAMES: + return + elif node and node.fullname in ANNOTATED_TYPE_NAMES and t.args: + # Don't query the second argument to Annotated for TypeVars + self.process_types([t.args[0]]) + elif t.args: + self.process_types(t.args) + + def visit_type_list(self, t: TypeList) -> None: + self.process_types(t.items) + + def visit_callable_argument(self, t: CallableArgument) -> None: + t.typ.accept(self) + + def visit_any(self, t: AnyType) -> None: + pass + + def visit_uninhabited_type(self, t: UninhabitedType) -> None: + pass + + def visit_none_type(self, t: NoneType) -> None: + pass + + def visit_erased_type(self, t: ErasedType) -> None: + pass + + def visit_deleted_type(self, t: DeletedType) -> None: + pass + + def visit_type_var(self, t: TypeVarType) -> None: + self.process_types([t.upper_bound, t.default] + t.values) + + def visit_param_spec(self, t: ParamSpecType) -> None: + self.process_types([t.upper_bound, t.default]) + + def visit_type_var_tuple(self, t: TypeVarTupleType) -> None: + self.process_types([t.upper_bound, t.default]) + + def visit_unpack_type(self, t: UnpackType) -> None: + self.process_types([t.type]) + + def visit_parameters(self, t: Parameters) -> None: + self.process_types(t.arg_types) + + def visit_partial_type(self, t: PartialType) -> None: + pass + + def visit_instance(self, t: Instance) -> None: + self.process_types(t.args) + + def visit_callable_type(self, t: CallableType) -> None: + # FIX generics + self.process_types(t.arg_types) + t.ret_type.accept(self) + + def visit_tuple_type(self, t: TupleType) -> None: + self.process_types(t.items) + + def visit_typeddict_type(self, t: TypedDictType) -> None: + self.process_types(list(t.items.values())) + + def visit_raw_expression_type(self, t: RawExpressionType) -> None: + pass + + def visit_literal_type(self, t: LiteralType) -> None: + pass + + def visit_union_type(self, t: UnionType) -> None: + self.process_types(t.items) + + def visit_overloaded(self, t: Overloaded) -> None: + self.process_types(t.items) # type: ignore[arg-type] + + def visit_type_type(self, t: TypeType) -> None: + t.item.accept(self) + + def visit_ellipsis_type(self, t: EllipsisType) -> None: + pass + + def visit_placeholder_type(self, t: PlaceholderType) -> None: + return self.process_types(t.args) + + def visit_type_alias_type(self, t: TypeAliasType) -> None: + # Skip type aliases in already visited types to avoid infinite recursion. + if self.seen_aliases is None: + self.seen_aliases = set() + elif t in self.seen_aliases: + return + self.seen_aliases.add(t) + self.process_types(t.args) + + def process_types(self, types: list[Type] | tuple[Type, ...]) -> None: + # Redundant type check helps mypyc. + if isinstance(types, list): + for t in types: + t.accept(self) + else: + for t in types: + t.accept(self)