From d6824ee104bb0bade2dd62e0f701b2e7b0e41f19 Mon Sep 17 00:00:00 2001 From: Jordandev678 <20153053+Jordandev678@users.noreply.github.com> Date: Wed, 18 Sep 2024 18:34:52 +0000 Subject: [PATCH] Add special handling for typing.get_args --- mypy/checkexpr.py | 33 +++++++++++ test-data/unit/check-get-args.test | 75 +++++++++++++++++++++++++ test-data/unit/fixtures/typing-full.pyi | 2 + 3 files changed, 110 insertions(+) create mode 100644 test-data/unit/check-get-args.test diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 22595c85e702..5150d0f4a735 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -544,6 +544,39 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) -> self.msg.cannot_use_function_with_type(e.callee.name, "TypedDict", e) elif typ.node.is_newtype: self.msg.cannot_use_function_with_type(e.callee.name, "NewType", e) + if ( + isinstance(e.callee, NameExpr) + and e.callee.fullname == "typing.get_args" + and len(e.args) == 1 + ): + #Special hanlding for get_args(), returns a typed tuple + #with the type set by the input + typ = None + if isinstance(e.args[0], IndexExpr): + self.accept(e.args[0].index) + typ = self.chk.lookup_type(e.args[0].index) + else: + try: + node = self.chk.lookup_qualified(e.args[0].name) + except KeyError: + # Undefined names should already be reported in semantic analysis. + pass + if node: + if isinstance(node.node, TypeAlias): + #Resolve type + typ = get_proper_type(node.node.target) + else: + typ = node.node.type + if ( typ is not None + and isinstance(typ, UnionType) + and all([isinstance(t, LiteralType) for t in typ.items]) + ): + # Returning strings is defined but order isn't so + # we need to return type * len of the union + return TupleType([typ] * len(typ.items), fallback=self.named_type("builtins.tuple")) + else: + # Fall back to what we did anyway (Tuple[Any]) + return TupleType([AnyType(TypeOfAny.special_form)], fallback=self.named_type("builtins.tuple")) self.try_infer_partial_type(e) type_context = None if isinstance(e.callee, LambdaExpr): diff --git a/test-data/unit/check-get-args.test b/test-data/unit/check-get-args.test new file mode 100644 index 000000000000..0d40c0a4e83b --- /dev/null +++ b/test-data/unit/check-get-args.test @@ -0,0 +1,75 @@ +[case getArgsReturnTypes] +from typing import Literal, Union, get_args +literals: Literal["a", "bc"] = "a" +unionliterals: Union[Literal["a"], Literal["bc"]] = "a" +intandliterals: Union[Literal["a", "bc"], int] = "a" +intandunionliterals: Union[Literal["a"], Literal["bc"], int] = "a" +inttype: int = 1 +reveal_type(get_args(literals)) # N: Revealed type is "Tuple[Union[Literal['a'], Literal['bc']], Union[Literal['a'], Literal['bc']]]" +reveal_type(get_args(unionliterals)) # N: Revealed type is "Tuple[Union[Literal['a'], Literal['bc']], Union[Literal['a'], Literal['bc']]]" +reveal_type(get_args(intandliterals)) # N: Revealed type is "Tuple[Any]" +reveal_type(get_args(intandunionliterals)) # N: Revealed type is "Tuple[Any]" +reveal_type(get_args(inttype)) # N: Revealed type is "Tuple[Any]" +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-full.pyi] + +[case getArgsVarTypesWithNarrowing] +from typing import Literal, get_args +from typing_extensions import TypeAlias +normalImplicit = Literal["a", "bc"] +normalExplicit: TypeAlias = Literal["a", "bc"] +reveal_type(get_args(normalImplicit)) # N: Revealed type is "Tuple[Union[Literal['a'], Literal['bc']], Union[Literal['a'], Literal['bc']]]" +reveal_type(get_args(normalExplicit)) # N: Revealed type is "Tuple[Union[Literal['a'], Literal['bc']], Union[Literal['a'], Literal['bc']]]" +#reveal_type(get_args(Literal["a", "bc"])) +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-full.pyi] + +[case testNarrowingInType] +from typing import Literal, get_args +type_alpha = Literal["a", "b", "c"] +strIn: str = "c" +strOut: str = "d" +if strIn in get_args(type_alpha): + reveal_type(strIn) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]" +else: + reveal_type(strIn) # N: Revealed type is "builtins.str" +if strOut in get_args(type_alpha): + reveal_type(strOut) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]" +else: + reveal_type(strOut) # N: Revealed type is "builtins.str" +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-full.pyi] + +[case testNarrowingNotInType] +from typing import Literal, get_args +type_alpha = Literal["a", "b", "c"] +strIn: str = "c" +strOut: str = "d" +if strIn not in get_args(type_alpha): + reveal_type(strIn) # N: Revealed type is "builtins.str" +else: + reveal_type(strIn) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]" +if strOut not in get_args(type_alpha): + reveal_type(strOut) # N: Revealed type is "builtins.str" +else: + reveal_type(strOut) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]" +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-full.pyi] + +[case i15106] +from typing import Literal, get_args, Optional +from typing_extensions import TypeAlias +ExpectedUserInput: TypeAlias = Literal[ + "these", "strings", "are", "expected", "user", "input" +] +def external_function(input: str) -> Optional[str]: + if input not in get_args(ExpectedUserInput): + reveal_type(input) # N: Revealed type is "builtins.str" + return None + reveal_type(input) # N: Revealed type is "Union[Literal['these'], Literal['strings'], Literal['are'], Literal['expected'], Literal['user'], Literal['input']]" + return _internal_function(input) + +def _internal_function(input: ExpectedUserInput) -> str: + return "User input: {input}" +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-full.pyi] \ No newline at end of file diff --git a/test-data/unit/fixtures/typing-full.pyi b/test-data/unit/fixtures/typing-full.pyi index 8e0116aab1c2..e732d532a537 100644 --- a/test-data/unit/fixtures/typing-full.pyi +++ b/test-data/unit/fixtures/typing-full.pyi @@ -204,6 +204,8 @@ def dataclass_transform( ) -> Callable[[T], T]: ... def override(__arg: T) -> T: ... +def get_args(tp: T) -> T: ... + # Was added in 3.11 def reveal_type(__obj: T) -> T: ...