Skip to content

Commit

Permalink
Merge pull request numpy#27681 from jorenham/typing/non-existant-scal…
Browse files Browse the repository at this point in the history
…ar-methods

TYP: Fix some inconsistencies in the scalar methods and properties
  • Loading branch information
charris authored Nov 11, 2024
2 parents 7c0e2e4 + 38814d9 commit 20d051a
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 84 deletions.
145 changes: 63 additions & 82 deletions numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ from typing import (
# library include `typing_extensions` stubs:
# https://github.com/python/typeshed/blob/main/stdlib/typing_extensions.pyi
from _typeshed import StrOrBytesPath, SupportsFlush, SupportsLenAndGetItem, SupportsWrite
from typing_extensions import CapsuleType, Generic, LiteralString, Protocol, Self, TypeVar, overload
from typing_extensions import CapsuleType, Generic, LiteralString, Protocol, Self, TypeVar, deprecated, overload

from numpy import (
core,
Expand Down Expand Up @@ -1377,6 +1377,10 @@ _SortSide: TypeAlias = L["left", "right"]

@type_check_only
class _ArrayOrScalarCommon:
@property
def real(self, /) -> Any: ...
@property
def imag(self, /) -> Any: ...
@property
def T(self) -> Self: ...
@property
Expand All @@ -1391,17 +1395,18 @@ class _ArrayOrScalarCommon:
def nbytes(self) -> int: ...
@property
def device(self) -> L["cpu"]: ...
def __bool__(self) -> builtins.bool: ...
def __bytes__(self) -> bytes: ...
def __str__(self) -> str: ...
def __repr__(self) -> str: ...

def __bool__(self, /) -> builtins.bool: ...
def __int__(self, /) -> int: ...
def __float__(self, /) -> float: ...
def __copy__(self) -> Self: ...
def __deepcopy__(self, memo: None | dict[int, Any], /) -> Self: ...

# TODO: How to deal with the non-commutative nature of `==` and `!=`?
# xref numpy/numpy#17368
def __eq__(self, other: Any, /) -> Any: ...
def __ne__(self, other: Any, /) -> Any: ...

def copy(self, order: _OrderKACF = ...) -> Self: ...
def dump(self, file: StrOrBytesPath | SupportsWrite[bytes]) -> None: ...
def dumps(self) -> bytes: ...
Expand All @@ -1418,7 +1423,7 @@ class _ArrayOrScalarCommon:
@property
def __array_priority__(self) -> float: ...
@property
def __array_struct__(self) -> Any: ... # builtins.PyCapsule
def __array_struct__(self) -> CapsuleType: ... # builtins.PyCapsule
def __array_namespace__(self, /, *, api_version: _ArrayAPIVersion | None = None) -> ModuleType: ...
def __setstate__(self, state: tuple[
SupportsIndex, # version
Expand Down Expand Up @@ -2230,8 +2235,8 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
) -> NDArray[Any]: ...

def __index__(self: NDArray[np.integer[Any]], /) -> int: ...
def __int__(self: NDArray[number[Any] | np.bool | object_], /) -> int: ...
def __float__(self: NDArray[number[Any] | np.bool | object_], /) -> float: ...
def __int__(self: NDArray[number[Any] | np.timedelta64 | np.bool | object_], /) -> int: ...
def __float__(self: NDArray[number[Any] | np.timedelta64 | np.bool | object_], /) -> float: ...
def __complex__(self: NDArray[number[Any] | np.bool | object_], /) -> complex: ...

def __len__(self) -> int: ...
Expand Down Expand Up @@ -3254,14 +3259,7 @@ class generic(_ArrayOrScalarCommon):
def dtype(self) -> _dtype[Self]: ...

class number(generic, Generic[_NBit1]): # type: ignore
@property
def real(self) -> Self: ...
@property
def imag(self) -> Self: ...
def __class_getitem__(cls, item: Any, /) -> GenericAlias: ...
def __int__(self) -> int: ...
def __float__(self) -> float: ...
def __complex__(self) -> complex: ...
def __neg__(self) -> Self: ...
def __pos__(self) -> Self: ...
def __abs__(self) -> Self: ...
Expand All @@ -3283,19 +3281,19 @@ class number(generic, Generic[_NBit1]): # type: ignore
__gt__: _ComparisonOpGT[_NumberLike_co, _ArrayLikeNumber_co]
__ge__: _ComparisonOpGE[_NumberLike_co, _ArrayLikeNumber_co]

class bool(generic):
def __init__(self, value: object = ..., /) -> None: ...
def item(
self, args: L[0] | tuple[()] | tuple[L[0]] = ..., /,
) -> builtins.bool: ...
def tolist(self) -> builtins.bool: ...
@type_check_only
class _RealMixin:
@property
def real(self) -> Self: ...
@property
def imag(self) -> Self: ...
def __int__(self) -> int: ...
def __float__(self) -> float: ...
def __complex__(self) -> complex: ...

class bool(_RealMixin, generic):
def __init__(self, value: object = ..., /) -> None: ...
def item(self, args: L[0] | tuple[()] | tuple[L[0]] = ..., /) -> builtins.bool: ...
def tolist(self) -> builtins.bool: ...
@deprecated("In future, it will be an error for 'np.bool' scalars to be interpreted as an index")
def __index__(self, /) -> int: ...
def __abs__(self) -> Self: ...
__add__: _BoolOp[np.bool]
__radd__: _BoolOp[np.bool]
Expand Down Expand Up @@ -3332,13 +3330,12 @@ class bool(generic):
bool_: TypeAlias = bool

_StringType = TypeVar("_StringType", bound=str | bytes)
_ShapeType = TypeVar("_ShapeType", bound=_Shape)
_ObjectType = TypeVar("_ObjectType", bound=object)

# The `object_` constructor returns the passed object, so instances with type
# `object_` cannot exists (at runtime).
@final
class object_(generic):
class object_(_RealMixin, generic):
@overload
def __new__(cls, nothing_to_see_here: None = ..., /) -> None: ...
@overload
Expand All @@ -3353,16 +3350,6 @@ class object_(generic):
@overload
def __new__(cls, value: Any = ..., /) -> object | NDArray[object_]: ...

@property
def real(self) -> Self: ...
@property
def imag(self) -> Self: ...
# The 3 protocols below may or may not raise,
# depending on the underlying object
def __int__(self) -> int: ...
def __float__(self) -> float: ...
def __complex__(self) -> complex: ...

if sys.version_info >= (3, 12):
def __release_buffer__(self, buffer: memoryview, /) -> None: ...

Expand All @@ -3379,7 +3366,7 @@ class _DatetimeScalar(Protocol):

# TODO: `item`/`tolist` returns either `dt.date`, `dt.datetime` or `int`
# depending on the unit
class datetime64(generic):
class datetime64(_RealMixin, generic):
@overload
def __init__(
self,
Expand Down Expand Up @@ -3417,25 +3404,29 @@ _ComplexValue: TypeAlias = (
| complex # `complex` is not a subtype of `SupportsComplex`
)

class integer(number[_NBit1]): # type: ignore
@type_check_only
class _RoundMixin:
@overload
def __round__(self, /, ndigits: None = None) -> int: ...
@overload
def __round__(self, /, ndigits: SupportsIndex) -> Self: ...

@type_check_only
class _IntegralMixin(_RealMixin):
@property
def numerator(self) -> Self: ...
@property
def denominator(self) -> L[1]: ...
@overload
def __round__(self, ndigits: None = ..., /) -> int: ...
@overload
def __round__(self, ndigits: SupportsIndex, /) -> Self: ...

# NOTE: `__index__` is technically defined in the bottom-most
# sub-classes (`int64`, `uint32`, etc)
def item(
self, args: L[0] | tuple[()] | tuple[L[0]] = ..., /,
) -> int: ...
class integer(_IntegralMixin, _RoundMixin, number[_NBit1]): # type: ignore
def is_integer(self, /) -> L[True]: ...
def item(self, args: L[0] | tuple[()] | tuple[L[0]] = ..., /) -> int: ...
def tolist(self) -> int: ...
def is_integer(self) -> L[True]: ...
def bit_count(self) -> int: ...
def __index__(self) -> int: ...

# NOTE: `bit_count` and `__index__` are technically defined in the concrete subtypes
def bit_count(self, /) -> int: ...
def __index__(self, /) -> int: ...

__truediv__: _IntTrueDiv[_NBit1]
__rtruediv__: _IntTrueDiv[_NBit1]
def __mod__(self, value: _IntLike_co, /) -> integer[Any]: ...
Expand Down Expand Up @@ -3495,23 +3486,16 @@ longlong = signedinteger[_NBitLongLong]

# TODO: `item`/`tolist` returns either `dt.timedelta` or `int`
# depending on the unit
class timedelta64(generic):
class timedelta64(_IntegralMixin, generic):
def __init__(
self,
value: None | int | _CharLike_co | dt.timedelta | timedelta64 = ...,
format: _CharLike_co | tuple[_CharLike_co, _IntLike_co] = ...,
/,
) -> None: ...
@property
def numerator(self) -> Self: ...
@property
def denominator(self) -> L[1]: ...

# NOTE: Only a limited number of units support conversion
# to builtin scalar types: `Y`, `M`, `ns`, `ps`, `fs`, `as`
def __int__(self) -> int: ...
def __float__(self) -> float: ...
def __complex__(self) -> complex: ...
def __neg__(self) -> Self: ...
def __pos__(self) -> Self: ...
def __abs__(self) -> Self: ...
Expand Down Expand Up @@ -3577,18 +3561,15 @@ ulonglong: TypeAlias = unsignedinteger[_NBitLongLong]

class inexact(number[_NBit1]): ... # type: ignore[misc]

_IntType = TypeVar("_IntType", bound=integer[Any])

class floating(inexact[_NBit1]):
class floating(_RealMixin, _RoundMixin, inexact[_NBit1]):
def __init__(self, value: _FloatValue = ..., /) -> None: ...
def item(self, args: L[0] | tuple[()] | tuple[L[0]] = ..., /) -> float: ...
def tolist(self) -> float: ...
def is_integer(self) -> builtins.bool: ...
def as_integer_ratio(self) -> tuple[int, int]: ...
@overload
def __round__(self, ndigits: None = ..., /) -> int: ...
@overload
def __round__(self, ndigits: SupportsIndex, /) -> Self: ...

# NOTE: `is_integer` and `as_integer_ratio` are technically defined in the concrete subtypes
def is_integer(self, /) -> builtins.bool: ...
def as_integer_ratio(self, /) -> tuple[int, int]: ...

__add__: _FloatOp[_NBit1]
__radd__: _FloatOp[_NBit1]
__sub__: _FloatOp[_NBit1]
Expand Down Expand Up @@ -3617,7 +3598,7 @@ class float64(floating[_64Bit], float): # type: ignore[misc]
def __getformat__(self, typestr: L["double", "float"], /) -> str: ...
def __getnewargs__(self, /) -> tuple[float]: ...

# overrides for `floating` and `builtins.float` compatibility
# overrides for `floating` and `builtins.float` compatibility (`_RealMixin` doesn't work)
@property
def real(self) -> Self: ...
@property
Expand Down Expand Up @@ -3754,9 +3735,16 @@ class complexfloating(inexact[_NBit1], Generic[_NBit1, _NBit2]):
def real(self) -> floating[_NBit1]: ... # type: ignore[override]
@property
def imag(self) -> floating[_NBit2]: ... # type: ignore[override]
def __abs__(self) -> floating[_NBit1 | _NBit2]: ... # type: ignore[override]
# NOTE: Deprecated
# def __round__(self, ndigits=...): ...

# NOTE: `__complex__` is technically defined in the concrete subtypes
def __complex__(self, /) -> complex: ...
def __abs__(self, /) -> floating[_NBit1 | _NBit2]: ... # type: ignore[override]
@deprecated(
"The Python built-in `round` is deprecated for complex scalars, and will raise a `TypeError` in a future release. "
"Use `np.round` or `scalar.round` instead."
)
def __round__(self, /, ndigits: SupportsIndex | None = None) -> Self: ...

@overload
def __add__(self, other: _Complex64_co, /) -> complexfloating[_NBit1, _NBit2]: ...
@overload
Expand Down Expand Up @@ -3871,7 +3859,7 @@ csingle: TypeAlias = complexfloating[_NBitSingle, _NBitSingle]
cdouble: TypeAlias = complexfloating[_NBitDouble, _NBitDouble]
clongdouble: TypeAlias = complexfloating[_NBitLongDouble, _NBitLongDouble]

class flexible(generic): ... # type: ignore
class flexible(_RealMixin, generic): ... # type: ignore

# TODO: `item`/`tolist` returns either `bytes` or `tuple`
# depending on whether or not it's used as an opaque bytes sequence
Expand All @@ -3881,13 +3869,7 @@ class void(flexible):
def __init__(self, value: _IntLike_co | bytes, /, dtype : None = ...) -> None: ...
@overload
def __init__(self, value: Any, /, dtype: _DTypeLikeVoid) -> None: ...
@property
def real(self) -> Self: ...
@property
def imag(self) -> Self: ...
def setfield(
self, val: ArrayLike, dtype: DTypeLike, offset: int = ...
) -> None: ...
def setfield(self, val: ArrayLike, dtype: DTypeLike, offset: int = ...) -> None: ...
@overload
def __getitem__(self, key: str | SupportsIndex, /) -> Any: ...
@overload
Expand All @@ -3899,9 +3881,7 @@ class void(flexible):
/,
) -> None: ...

class character(flexible): # type: ignore
def __int__(self) -> int: ...
def __float__(self) -> float: ...
class character(flexible): ... # type: ignore

# NOTE: Most `np.bytes_` / `np.str_` methods return their
# builtin `bytes` / `str` counterpart
Expand All @@ -3913,6 +3893,7 @@ class bytes_(character, bytes):
def __init__(
self, value: str, /, encoding: str = ..., errors: str = ...
) -> None: ...
def __bytes__(self, /) -> bytes: ...
def item(
self, args: L[0] | tuple[()] | tuple[L[0]] = ..., /,
) -> bytes: ...
Expand Down
2 changes: 0 additions & 2 deletions numpy/typing/tests/data/fail/scalars.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ def func(a: np.float32) -> None: ...
func(f2) # E: incompatible type
func(f8) # E: incompatible type

round(c8) # E: No overload variant

c8.__getnewargs__() # E: Invalid self argument
f2.__getnewargs__() # E: Invalid self argument
f2.hex() # E: Invalid self argument
Expand Down

0 comments on commit 20d051a

Please sign in to comment.