Skip to content

Commit

Permalink
Allow self binding for generic ParamSpec
Browse files Browse the repository at this point in the history
  • Loading branch information
cdce8p committed Apr 14, 2024
1 parent 4310586 commit fd4a62a
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 1 deletion.
81 changes: 80 additions & 1 deletion mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
LiteralType,
NoneType,
Overloaded,
Parameters,
ParamSpecType,
PartialType,
ProperType,
Expand Down Expand Up @@ -792,6 +793,33 @@ def analyze_var(
else:
call_type = typ

if isinstance(call_type, Instance) and any(
isinstance(arg, Parameters) for arg in call_type.args
):
args: list[Type] = []
for arg in call_type.args:
if not isinstance(arg, Parameters):
args.append(arg)
continue
c = callable_type_from_parameters(arg, mx.chk.named_type("builtins.function"))
if not var.is_staticmethod:
functype: FunctionLike = c
dispatched_type = meet.meet_types(mx.original_type, itype)
signature = freshen_all_functions_type_vars(functype)
bound = get_proper_type(expand_self_type(var, signature, mx.original_type))
assert isinstance(bound, FunctionLike)
signature = bound
signature = check_self_arg(
signature, dispatched_type, var.is_classmethod, mx.context, name, mx.msg
)
signature = bind_self(signature, mx.self_type, var.is_classmethod)
expanded_signature = expand_type_by_instance(signature, itype)
freeze_all_type_vars(expanded_signature)
assert isinstance(expanded_signature, CallableType)
arg = update_parameters_from_signature(arg, expanded_signature)
args.append(arg)
call_type = call_type.copy_modified(args=args)
result = call_type
if isinstance(call_type, FunctionLike) and not call_type.is_type_obj():
if mx.is_lvalue:
if var.is_property:
Expand All @@ -803,7 +831,7 @@ def analyze_var(
if not var.is_staticmethod:
# Class-level function objects and classmethods become bound methods:
# the former to the instance, the latter to the class.
functype: FunctionLike = call_type
functype = call_type
# Use meet to narrow original_type to the dispatched type.
# For example, assume
# * A.f: Callable[[A1], None] where A1 <: A (maybe A1 == A)
Expand Down Expand Up @@ -1061,6 +1089,30 @@ def analyze_class_attribute_access(
isinstance(node.node, FuncBase) and node.node.is_static
)
t = get_proper_type(t)
if isinstance(t, Instance) and any(isinstance(arg, Parameters) for arg in t.args):
args: list[Type] = []
for arg in t.args:
if not isinstance(arg, Parameters):
args.append(arg)
continue
c: FunctionLike = callable_type_from_parameters(
arg, mx.chk.named_type("builtins.function")
)
if is_classmethod:
c = check_self_arg(c, mx.self_type, False, mx.context, name, mx.msg)
res = add_class_tvars(
c,
isuper,
is_classmethod,
is_staticmethod,
mx.self_type,
original_vars=original_vars,
)
signature = get_proper_type(res)
assert isinstance(signature, CallableType)
arg = update_parameters_from_signature(arg, signature)
args.append(arg)
t = t.copy_modified(args=args)
if isinstance(t, FunctionLike) and is_classmethod:
t = check_self_arg(t, mx.self_type, False, mx.context, name, mx.msg)
result = add_class_tvars(
Expand Down Expand Up @@ -1348,3 +1400,30 @@ def is_valid_constructor(n: SymbolNode | None) -> bool:
if isinstance(n, Decorator):
return isinstance(get_proper_type(n.type), FunctionLike)
return False


def callable_type_from_parameters(
param: Parameters, fallback: Instance, ret_type: Type | None = None
) -> CallableType:
"""Create CallableType from Parameters."""
return CallableType(
arg_types=param.arg_types,
arg_kinds=param.arg_kinds,
arg_names=param.arg_names,
ret_type=ret_type if ret_type is not None else NoneType(),
fallback=fallback,
variables=param.variables,
imprecise_arg_kinds=param.imprecise_arg_kinds,
)


def update_parameters_from_signature(param: Parameters, signature: CallableType) -> Parameters:
"""Update Parameters from signature."""
return param.copy_modified(
arg_types=signature.arg_types,
arg_kinds=signature.arg_kinds,
arg_names=signature.arg_names,
is_ellipsis_args=signature.is_ellipsis_args,
variables=signature.variables,
imprecise_arg_kinds=signature.imprecise_arg_kinds,
)
95 changes: 95 additions & 0 deletions test-data/unit/check-selftype.test
Original file line number Diff line number Diff line change
Expand Up @@ -2071,3 +2071,98 @@ p: Partial
reveal_type(p()) # N: Revealed type is "Never"
p2: Partial2
reveal_type(p2(42)) # N: Revealed type is "builtins.int"

[case testSelfTypeBindingWithGenericParamSpec]
from typing import Generic, Callable
from typing_extensions import ParamSpec

P = ParamSpec("P")

class Wrapper(Generic[P]):
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None: ...

def decorator(f: Callable[P, None]) -> Callable[P, None]: ...
def lru_cache(f: Callable[P, None]) -> Wrapper[P]: ...

class A:
@decorator
def method1(self, val: int) -> None: ...

@lru_cache
def method2(self, val: int) -> None: ...

def test(self) -> None:
reveal_type(self.method1) # N: Revealed type is "def (val: builtins.int)"
reveal_type(self.method2) # N: Revealed type is "__main__.Wrapper[[val: builtins.int]]"
reveal_type(A.method1) # N: Revealed type is "def (self: __main__.A, val: builtins.int)"
reveal_type(A.method2) # N: Revealed type is "__main__.Wrapper[[self: __main__.A, val: builtins.int]]"

self.method1(2)
self.method2(2)
[builtins fixtures/tuple.pyi]

[case testSelfTypeBindingWithGenericParamSpecClassmethod]
from typing import Generic, Callable
from typing_extensions import ParamSpec

P = ParamSpec("P")

class Wrapper(Generic[P]):
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None: ...

def decorator(f: Callable[P, None]) -> Callable[P, None]: ...
def lru_cache(f: Callable[P, None]) -> Wrapper[P]: ...

class A:
@classmethod
@decorator
def method1(cls, val: int) -> None: ...

@classmethod
@lru_cache
def method2(cls, val: int) -> None: ...

def test(self) -> None:
reveal_type(self.method1) # N: Revealed type is "def (val: builtins.int)"
reveal_type(self.method2) # N: Revealed type is "__main__.Wrapper[[val: builtins.int]]"
reveal_type(A.method1) # N: Revealed type is "def (val: builtins.int)"
reveal_type(A.method2) # N: Revealed type is "__main__.Wrapper[[val: builtins.int]]"

self.method1(2)
self.method2(2)
A.method1(2)
A.method2(2)
[builtins fixtures/classmethod.pyi]

[case testSelfTypeBindingWithGenericParamSpecStaticmethod]
from typing import Generic, Callable
from typing_extensions import ParamSpec

P = ParamSpec("P")

class Wrapper(Generic[P]):
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None: ...

def decorator(f: Callable[P, None]) -> Callable[P, None]: ...
def lru_cache(f: Callable[P, None]) -> Wrapper[P]: ...

class A:
@staticmethod
@decorator
def method1(val: int) -> None: ...

@staticmethod
@lru_cache
def method2(val: int) -> None: ...

def test(self) -> None:
reveal_type(self.method1) # N: Revealed type is "def (val: builtins.int)"
reveal_type(self.method2) # N: Revealed type is "__main__.Wrapper[[val: builtins.int]]"
reveal_type(A.method1) # N: Revealed type is "def (val: builtins.int)"
reveal_type(A.method2) # N: Revealed type is "__main__.Wrapper[[val: builtins.int]]"

self.method1(2)
self.method2(2)
A.method1(2)
A.method2(2)
[builtins fixtures/classmethod.pyi]

0 comments on commit fd4a62a

Please sign in to comment.