Skip to content

Commit

Permalink
Add special handling for typing.get_args
Browse files Browse the repository at this point in the history
  • Loading branch information
Jordandev678 committed Sep 18, 2024
1 parent 9d7a042 commit d6824ee
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 0 deletions.
33 changes: 33 additions & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,39 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
self.msg.cannot_use_function_with_type(e.callee.name, "TypedDict", e)
elif typ.node.is_newtype:
self.msg.cannot_use_function_with_type(e.callee.name, "NewType", e)
if (
isinstance(e.callee, NameExpr)
and e.callee.fullname == "typing.get_args"
and len(e.args) == 1
):
#Special hanlding for get_args(), returns a typed tuple
#with the type set by the input
typ = None
if isinstance(e.args[0], IndexExpr):
self.accept(e.args[0].index)
typ = self.chk.lookup_type(e.args[0].index)
else:
try:
node = self.chk.lookup_qualified(e.args[0].name)
except KeyError:
# Undefined names should already be reported in semantic analysis.
pass
if node:
if isinstance(node.node, TypeAlias):
#Resolve type
typ = get_proper_type(node.node.target)
else:
typ = node.node.type
if ( typ is not None
and isinstance(typ, UnionType)
and all([isinstance(t, LiteralType) for t in typ.items])
):
# Returning strings is defined but order isn't so
# we need to return type * len of the union
return TupleType([typ] * len(typ.items), fallback=self.named_type("builtins.tuple"))
else:
# Fall back to what we did anyway (Tuple[Any])
return TupleType([AnyType(TypeOfAny.special_form)], fallback=self.named_type("builtins.tuple"))
self.try_infer_partial_type(e)
type_context = None
if isinstance(e.callee, LambdaExpr):
Expand Down
75 changes: 75 additions & 0 deletions test-data/unit/check-get-args.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
[case getArgsReturnTypes]
from typing import Literal, Union, get_args
literals: Literal["a", "bc"] = "a"
unionliterals: Union[Literal["a"], Literal["bc"]] = "a"
intandliterals: Union[Literal["a", "bc"], int] = "a"
intandunionliterals: Union[Literal["a"], Literal["bc"], int] = "a"
inttype: int = 1
reveal_type(get_args(literals)) # N: Revealed type is "Tuple[Union[Literal['a'], Literal['bc']], Union[Literal['a'], Literal['bc']]]"
reveal_type(get_args(unionliterals)) # N: Revealed type is "Tuple[Union[Literal['a'], Literal['bc']], Union[Literal['a'], Literal['bc']]]"
reveal_type(get_args(intandliterals)) # N: Revealed type is "Tuple[Any]"
reveal_type(get_args(intandunionliterals)) # N: Revealed type is "Tuple[Any]"
reveal_type(get_args(inttype)) # N: Revealed type is "Tuple[Any]"
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-full.pyi]

[case getArgsVarTypesWithNarrowing]
from typing import Literal, get_args
from typing_extensions import TypeAlias
normalImplicit = Literal["a", "bc"]
normalExplicit: TypeAlias = Literal["a", "bc"]
reveal_type(get_args(normalImplicit)) # N: Revealed type is "Tuple[Union[Literal['a'], Literal['bc']], Union[Literal['a'], Literal['bc']]]"
reveal_type(get_args(normalExplicit)) # N: Revealed type is "Tuple[Union[Literal['a'], Literal['bc']], Union[Literal['a'], Literal['bc']]]"
#reveal_type(get_args(Literal["a", "bc"]))
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-full.pyi]

[case testNarrowingInType]
from typing import Literal, get_args
type_alpha = Literal["a", "b", "c"]
strIn: str = "c"
strOut: str = "d"
if strIn in get_args(type_alpha):
reveal_type(strIn) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
else:
reveal_type(strIn) # N: Revealed type is "builtins.str"
if strOut in get_args(type_alpha):
reveal_type(strOut) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
else:
reveal_type(strOut) # N: Revealed type is "builtins.str"
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-full.pyi]

[case testNarrowingNotInType]
from typing import Literal, get_args
type_alpha = Literal["a", "b", "c"]
strIn: str = "c"
strOut: str = "d"
if strIn not in get_args(type_alpha):
reveal_type(strIn) # N: Revealed type is "builtins.str"
else:
reveal_type(strIn) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
if strOut not in get_args(type_alpha):
reveal_type(strOut) # N: Revealed type is "builtins.str"
else:
reveal_type(strOut) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-full.pyi]

[case i15106]
from typing import Literal, get_args, Optional
from typing_extensions import TypeAlias
ExpectedUserInput: TypeAlias = Literal[
"these", "strings", "are", "expected", "user", "input"
]
def external_function(input: str) -> Optional[str]:
if input not in get_args(ExpectedUserInput):
reveal_type(input) # N: Revealed type is "builtins.str"
return None
reveal_type(input) # N: Revealed type is "Union[Literal['these'], Literal['strings'], Literal['are'], Literal['expected'], Literal['user'], Literal['input']]"
return _internal_function(input)

def _internal_function(input: ExpectedUserInput) -> str:
return "User input: {input}"
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-full.pyi]
2 changes: 2 additions & 0 deletions test-data/unit/fixtures/typing-full.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ def dataclass_transform(
) -> Callable[[T], T]: ...
def override(__arg: T) -> T: ...

def get_args(tp: T) -> T: ...

# Was added in 3.11
def reveal_type(__obj: T) -> T: ...

Expand Down

0 comments on commit d6824ee

Please sign in to comment.