Skip to content

Commit

Permalink
Merge pull request numpy#27653 from jorenham/typing/ndarray-array-api
Browse files Browse the repository at this point in the history
TYP: Fix Array API method signatures
  • Loading branch information
charris authored Oct 28, 2024
2 parents 7ed62d2 + 8d0a319 commit 70fde29
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 26 deletions.
37 changes: 16 additions & 21 deletions numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import array as _array
import datetime as dt
import enum
from abc import abstractmethod
from types import EllipsisType, TracebackType, MappingProxyType, GenericAlias
from types import EllipsisType, ModuleType, TracebackType, MappingProxyType, GenericAlias
from decimal import Decimal
from fractions import Fraction
from uuid import UUID
Expand Down Expand Up @@ -210,7 +210,7 @@ from typing import (
# This is because the `typeshed` stubs for the standard library include
# `typing_extensions` stubs:
# https://github.com/python/typeshed/blob/main/stdlib/typing_extensions.pyi
from typing_extensions import Generic, LiteralString, Protocol, Self, TypeVar, overload
from typing_extensions import CapsuleType, Generic, LiteralString, Protocol, Self, TypeVar, overload

from numpy import (
core,
Expand Down Expand Up @@ -763,7 +763,7 @@ class _SupportsWrite(Protocol[_AnyStr_contra]):
def write(self, s: _AnyStr_contra, /) -> object: ...

__version__: LiteralString
__array_api_version__: LiteralString
__array_api_version__: Final = "2023.12"
test: PytestTester


Expand Down Expand Up @@ -1431,7 +1431,7 @@ class _ArrayOrScalarCommon:
def __array_priority__(self) -> float: ...
@property
def __array_struct__(self) -> Any: ... # builtins.PyCapsule
def __array_namespace__(self, *, api_version: None | _ArrayAPIVersion = ...) -> Any: ...
def __array_namespace__(self, /, *, api_version: _ArrayAPIVersion | None = None) -> ModuleType: ...
def __setstate__(self, state: tuple[
SupportsIndex, # version
_ShapeLike, # Shape
Expand Down Expand Up @@ -1798,11 +1798,6 @@ _ArrayTD64_co: TypeAlias = NDArray[np.bool | integer[Any] | timedelta64]
# Introduce an alias for `dtype` to avoid naming conflicts.
_dtype: TypeAlias = dtype[_ScalarType]

if sys.version_info >= (3, 13):
from types import CapsuleType as _PyCapsule
else:
_PyCapsule: TypeAlias = Any

_ArrayAPIVersion: TypeAlias = L["2021.12", "2022.12", "2023.12"]

@type_check_only
Expand Down Expand Up @@ -3063,14 +3058,14 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):

def __dlpack__(
self: NDArray[number[Any]],
/,
*,
stream: int | Any | None = ...,
max_version: tuple[int, int] | None = ...,
dl_device: tuple[int, L[0]] | None = ...,
copy: bool | None = ...,
) -> _PyCapsule: ...

def __dlpack_device__(self) -> tuple[int, L[0]]: ...
stream: int | Any | None = None,
max_version: tuple[int, int] | None = None,
dl_device: tuple[int, int] | None = None,
copy: builtins.bool | None = None,
) -> CapsuleType: ...
def __dlpack_device__(self, /) -> tuple[L[1], L[0]]: ...

def bitwise_count(
self,
Expand Down Expand Up @@ -4727,12 +4722,12 @@ class matrix(ndarray[_Shape2DType_co, _DType_co]):

@type_check_only
class _SupportsDLPack(Protocol[_T_contra]):
def __dlpack__(self, *, stream: None | _T_contra = ...) -> _PyCapsule: ...
def __dlpack__(self, /, *, stream: _T_contra | None = None) -> CapsuleType: ...

def from_dlpack(
obj: _SupportsDLPack[None],
x: _SupportsDLPack[None],
/,
*,
device: L["cpu"] | None = ...,
copy: bool | None = ...,
) -> NDArray[Any]: ...
device: L["cpu"] | None = None,
copy: builtins.bool | None = None,
) -> NDArray[number[Any] | np.bool]: ...
11 changes: 6 additions & 5 deletions numpy/typing/tests/data/reveal/ndarray_misc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ function-based counterpart in `../from_numeric.py`.

import operator
import ctypes as ct
from types import ModuleType
from typing import Any, Literal

import numpy as np
import numpy.typing as npt

from typing_extensions import assert_type
from typing_extensions import CapsuleType, assert_type

class SubClass(npt.NDArray[np.object_]): ...

Expand All @@ -30,8 +31,8 @@ AR_V: npt.NDArray[np.void]

ctypes_obj = AR_f8.ctypes

assert_type(AR_f8.__dlpack__(), Any)
assert_type(AR_f8.__dlpack_device__(), tuple[int, Literal[0]])
assert_type(AR_f8.__dlpack__(), CapsuleType)
assert_type(AR_f8.__dlpack_device__(), tuple[Literal[1], Literal[0]])

assert_type(ctypes_obj.data, int)
assert_type(ctypes_obj.shape, ct.Array[np.ctypeslib.c_intp])
Expand Down Expand Up @@ -225,5 +226,5 @@ assert_type(AR_u1.to_device("cpu"), npt.NDArray[np.uint8])
assert_type(AR_c8.to_device("cpu"), npt.NDArray[np.complex64])
assert_type(AR_m.to_device("cpu"), npt.NDArray[np.timedelta64])

assert_type(f8.__array_namespace__(), Any)
assert_type(AR_f8.__array_namespace__(), Any)
assert_type(f8.__array_namespace__(), ModuleType)
assert_type(AR_f8.__array_namespace__(), ModuleType)

0 comments on commit 70fde29

Please sign in to comment.