diff --git a/pyrefly/lib/alt/attr.rs b/pyrefly/lib/alt/attr.rs index e69c1dbcab..96e26cc335 100644 --- a/pyrefly/lib/alt/attr.rs +++ b/pyrefly/lib/alt/attr.rs @@ -50,6 +50,7 @@ use crate::types::read_only::ReadOnlyReason; use crate::types::type_var::Restriction; use crate::types::typed_dict::TypedDict; use crate::types::types::AnyStyle; +use crate::types::types::BoundMethodType; use crate::types::types::Overload; use crate::types::types::SuperObj; use crate::types::types::Type; @@ -434,6 +435,9 @@ enum AttributeBase1 { TypedDict(TypedDict), /// Attribute lookup on a base as part of a subset check against a protocol. ProtocolSubset(Box), + /// Bound methods prefer exposing builtin `types.MethodType` attributes but fall back to the + /// underlying function's attributes when the builtin ones are missing. + BoundMethod(BoundMethodType), } impl AttributeBase1 { @@ -1223,15 +1227,24 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { self.lookup_attr_from_attribute_base1((**protocol_base).clone(), attr_name, acc) } } - AttributeBase1::TypeQuantified(quantified, class) => { - if let Some(attr) = self.get_bounded_quantified_class_attribute( - quantified.clone(), - class, - attr_name, - ) { - acc.found_class_attribute(attr, base); + AttributeBase1::BoundMethod(bound_func) => { + let method_type_base = + AttributeBase1::ClassInstance(self.stdlib.method_type().clone()); + let found_len = acc.found.len(); + let not_found_len = acc.not_found.len(); + let error_len = acc.internal_error.len(); + self.lookup_attr_from_attribute_base1(method_type_base, attr_name, acc); + if acc.found.len() == found_len { + acc.not_found.truncate(not_found_len); + acc.internal_error.truncate(error_len); + let mut func_bases = Vec::new(); + self.as_attribute_base1(bound_func.clone().as_type(), &mut func_bases); + for base1 in func_bases { + self.lookup_attr_from_attribute_base1(base1, attr_name, acc); + } } else { - acc.not_found(NotFoundOn::ClassObject(class.class_object().dupe(), base)); + acc.not_found.truncate(not_found_len); + acc.internal_error.truncate(error_len); } } @@ -1725,9 +1738,9 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { self.stdlib.function_type().clone() }, )), - Type::BoundMethod(_) => acc.push(AttributeBase1::ClassInstance( - self.stdlib.method_type().clone(), - )), + Type::BoundMethod(bound_method) => { + acc.push(AttributeBase1::BoundMethod(bound_method.func.clone())); + } Type::Ellipsis => { if let Some(cls) = self.stdlib.ellipsis_type() { acc.push(AttributeBase1::ClassInstance(cls.clone())) @@ -2149,8 +2162,20 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { AttributeBase1::ClassObject(class) => { self.completions_class(class.class_object(), expected_attribute_name, res) } - AttributeBase1::TypeQuantified(_, class) => { - self.completions_class(class.class_object(), expected_attribute_name, res) + AttributeBase1::BoundMethod(bound_func) => { + let before = res.len(); + self.completions_class_type( + self.stdlib.method_type(), + expected_attribute_name, + res, + ); + if res.len() == before { + let mut func_bases = Vec::new(); + self.as_attribute_base1(bound_func.clone().as_type(), &mut func_bases); + for base1 in func_bases { + self.completions_inner1(&base1, expected_attribute_name, res); + } + } } AttributeBase1::TypeAny(_) | AttributeBase1::TypeNever => self.completions_class_type( self.stdlib.builtins_type(), diff --git a/pyrefly/lib/test/descriptors.rs b/pyrefly/lib/test/descriptors.rs index 147bf6ea5e..c7b44aa073 100644 --- a/pyrefly/lib/test/descriptors.rs +++ b/pyrefly/lib/test/descriptors.rs @@ -206,6 +206,41 @@ C().d = "42" "#, ); +testcase!( + test_bound_method_preserves_function_attributes_from_descriptor, + r#" +from __future__ import annotations + +from typing import Callable + + +class CachedMethod: + def __init__(self, fn: Callable[[Constraint], int]) -> None: + self._fn = fn + + def __get__(self, obj: Constraint | None, owner: type[Constraint]) -> CachedMethod: + return self + + def __call__(self, obj: Constraint) -> int: + return self._fn(obj) + + def clear_cache(self, obj: Constraint) -> None: ... + + +def cache_on_self(fn: Callable[[Constraint], int]) -> CachedMethod: + return CachedMethod(fn) + + +class Constraint: + @cache_on_self + def pointwise_read_writes(self) -> int: + return 0 + + def clear_cache(self) -> None: + self.pointwise_read_writes.clear_cache(self) + "#, +); + testcase!( test_class_property_descriptor, r#"