Skip to content

Commit

Permalink
fix actual_ret_type for TypeGuard
Browse files Browse the repository at this point in the history
  • Loading branch information
sloboegen committed Apr 25, 2024
1 parent 796b51d commit 689bc95
Showing 1 changed file with 37 additions and 7 deletions.
44 changes: 37 additions & 7 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
CONTRAVARIANT,
COVARIANT,
ArgKind,
Block,
ClassDef,
SymbolTable,
TypeInfo,
)
from mypy.types import (
Expand Down Expand Up @@ -1018,25 +1021,18 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
param_spec = template.param_spec()

template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type
bool_type = UnionType(
[LiteralType(True, cactual_ret_type), LiteralType(False, cactual_ret_type)] # type: ignore[arg-type]
)

if template.type_guard is not None and cactual.type_guard is not None:
template_ret_type = template.type_guard
cactual_ret_type = cactual.type_guard
elif template.type_guard is not None:
template_ret_type = AnyType(TypeOfAny.special_form)
elif cactual.type_guard is not None:
cactual_ret_type = bool_type

if template.type_is is not None and cactual.type_is is not None:
template_ret_type = template.type_is
cactual_ret_type = cactual.type_is
elif template.type_is is not None:
template_ret_type = AnyType(TypeOfAny.special_form)
elif cactual.type_is is not None:
cactual_ret_type = bool_type

res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction))

Expand Down Expand Up @@ -1340,6 +1336,40 @@ def visit_type_type(self, template: TypeType) -> list[Constraint]:
else:
return []

def _make_type_info(
self,
name: str,
module_name: str | None = None,
mro: list[TypeInfo] | None = None,
bases: list[Instance] | None = None,
) -> TypeInfo:
"""Make a TypeInfo suitable for use in unit tests."""

class_def = ClassDef(name, Block([]), None, [])
class_def.fullname = name

if module_name is None:
if "." in name:
module_name = name.rsplit(".", 1)[0]
else:
module_name = "__main__"

info = TypeInfo(SymbolTable(), class_def, module_name)
if mro is None:
mro = []
if name != "builtins.object":
mro.append(self.oi)
info.mro = [info] + mro
if bases is None:
if mro:
# By default, assume that there is a single non-generic base.
bases = [Instance(mro[0], [])]
else:
bases = []
info.bases = bases

return info


def neg_op(op: int) -> int:
"""Map SubtypeOf to SupertypeOf and vice versa."""
Expand Down

0 comments on commit 689bc95

Please sign in to comment.