Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 27 additions & 43 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@ name: CI
permissions: read-all

on:
push:
branches:
- master
pull_request:
push:
branches: [master]
workflow_dispatch:

concurrency:
Expand All @@ -23,7 +22,8 @@ jobs:
steps:
- uses: actions/checkout@v4

- uses: scientific-python/[email protected]
- name: repo review
uses: scientific-python/[email protected]
with:
plugins: sp-repo-review

Expand All @@ -36,14 +36,12 @@ jobs:
- name: typos
uses: crate-ci/typos@master

- name: install uv
uses: astral-sh/setup-uv@v6
- uses: astral-sh/setup-uv@v6
with:
python-version: "3.13"

- name: ruff check
run: uv run ruff check --output-format=github

- name: ruff format
run: uv run ruff format --check

Expand All @@ -52,63 +50,49 @@ jobs:

typecheck:
timeout-minutes: 5
runs-on: ubuntu-latest

strategy:
fail-fast: false
matrix:
np: ["1.25", "1.26", "2.0", "2.1", "2.2"]

os: [ubuntu-latest]
np: ["1.25", "1.26", "2.0", "2.1", "2.2", "2.3"]
py: ["3.11"]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4

- name: setup uv
uses: astral-sh/setup-uv@v6
- uses: astral-sh/setup-uv@v6
with:
activate-environment: true
python-version: "3.11"
- run: uv sync

- name: Install numpy ${{ matrix.np }}
run: uv pip install "numpy==${{ matrix.np }}.*"

python-version: ${{ matrix.py }}
- name: basedpyright
run: >
uv run --no-sync
uv run --with="numpy==${{ matrix.np }}.*"
basedpyright -p scripts/config/bpr-np-${{ matrix.np }}.json

- name: mypy
run: uv run --no-sync scripts/my.py
run: >
uv run --with="numpy==${{ matrix.np }}.*"
scripts/my.py

test:
timeout-minutes: 5
runs-on: ${{ matrix.os }}

strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest]
py: ["3.11", "3.12", "3.13"]
include:
py: ["3.11", "3.13"]
np: ["1.25", "2.3"]
exclude:
- os: ubuntu-latest
py: "3.11"
py: "3.13"
np: "1.25"

- os: windows-latest
py: "3.13"
np: "1.25"
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4

- name: setup uv
uses: astral-sh/setup-uv@v6
- uses: astral-sh/setup-uv@v6
with:
activate-environment: true
python-version: ${{ matrix.py }}

- name: uv sync
run: uv sync

- name: Install old numpy
if: ${{ matrix.np == '1.25' }}
run: uv pip install "numpy==${{ matrix.np }}.*"

- name: pytest
run: uv run --no-sync pytest
run: >
uv run --with="numpy==${{ matrix.np }}.*"
pytest
48 changes: 20 additions & 28 deletions optype/numpy/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,38 +107,30 @@ def __dir__() -> list[str]:
] = np.ma.MaskedArray[NDT, np.dtype[SCT]]
"""

if _x.NP21:
# numpy >= 2.1: shape is covariant

@runtime_checkable
@set_module("optype.numpy")
class CanArray(Protocol[_NDT_co, _DTT_co]):
def __array__(self, /) -> np.ndarray[_NDT_co, _DTT_co]: ...
# NOTE: Before NumPy 2.1 the shape type parameter of `numpy.ndarray` was invariant. This
# lead to various issues, so we ignore that, and suppress two pyright errors that are
# reported when `numpy<2.1` is installed (inline `# pyright: ignore` won't work).

@runtime_checkable
@set_module("optype.numpy")
class CanArrayND(Protocol[_SCT_co, _NDT_co]):
"""
Similar to `onp.CanArray`, but must be sized (i.e. excludes scalars), and is
parameterized by only the scalar type (instead of the shape and dtype).
"""
# pyright: reportInvalidTypeVarUse=false

def __len__(self, /) -> int: ...
def __array__(self, /) -> np.ndarray[_NDT_co, np.dtype[_SCT_co]]: ...

else:
# numpy < 2.1: shape is invariant

@runtime_checkable
@set_module("optype.numpy")
class CanArray(Protocol[_NDT, _DTT_co]):
def __array__(self, /) -> np.ndarray[_NDT, _DTT_co]: ...

@runtime_checkable
@set_module("optype.numpy")
class CanArrayND(Protocol[_SCT_co, _NDT]):
def __len__(self, /) -> int: ...
def __array__(self, /) -> np.ndarray[_NDT, np.dtype[_SCT_co]]: ...
@runtime_checkable
@set_module("optype.numpy")
class CanArray(Protocol[_NDT_co, _DTT_co]):
def __array__(self, /) -> np.ndarray[_NDT_co, _DTT_co]: ...


@runtime_checkable
@set_module("optype.numpy")
class CanArrayND(Protocol[_SCT_co, _NDT_co]):
"""
Similar to `onp.CanArray`, but must be sized (i.e. excludes scalars), and is
parameterized by only the scalar type (instead of the shape and dtype).
"""

def __len__(self, /) -> int: ...
def __array__(self, /) -> np.ndarray[_NDT_co, np.dtype[_SCT_co]]: ...


Array0D = TypeAliasType(
Expand Down
49 changes: 17 additions & 32 deletions optype/numpy/_ufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ class UFunc(Protocol[_FT_co, _NInT_co, _NoutT_co, _SigT_co, _IdT_co]):
This also includes gufunc's (generalized universion functions), which
have a specified `signature`, and aren't necessarily element-wise
functions (which "regular" ufuncs are).
At the moment (`numpy>=2.0,<2.2`), the only GUFuncs within numpy are
`matmul`, and `vecdot`.
At the moment (`numpy>=2.2,<2.4`), the only `GUFuncs` in the public numpy API
are `matmul`, `matvec`, `vecdot`, and `vecmat`, and all four have `nin == 2`
and `nout == 1`.
"""

@property
Expand Down Expand Up @@ -134,11 +135,12 @@ class UFunc(Protocol[_FT_co, _NInT_co, _NoutT_co, _SigT0_co, _IdT_co]):
A generic interface for `numpy.ufunc` "universal function" instances,
e.g. `numpy.exp`, `numpy.add`, `numpy.frexp`, `numpy.divmod`.

This also includes gufunc's (generalized universion functions), which
have a specified `signature`, and aren't necessarily element-wise
functions (which "regular" ufuncs are).
At the moment (`numpy>=2.0,<2.2`), the only GUFuncs within numpy are
`matmul`, and `vecdot`.
This also includes gufunc's (generalized universion functions), which have a
specified `signature`, and aren't necessarily element-wise functions
(which "regular" ufuncs are).
At the moment (`numpy>=2.2,<2.4`), the only `GUFuncs` in the public numpy API
are `matmul`, `matvec`, `vecdot`, and `vecmat`, and all four have `nin == 2`
and `nout == 1`.
"""

@property
Expand Down Expand Up @@ -189,31 +191,14 @@ class CanArrayUFunc(Protocol[_UFT_contra, _T_co]):
- https://numpy.org/devdocs/reference/arrays.classes.html
"""

# NOTE: Mypy doesn't understand the Liskov substitution principle when
# positional-only arguments are involved; so `ufunc` and `method` can't
# be made positional-only.

if _x.NP20:

def __array_ufunc__(
self,
/,
ufunc: _UFT_contra,
method: L[_MethodCommon, "at"],
*args: Any,
**kwargs: Any,
) -> _T_co: ...

else:

def __array_ufunc__(
self,
/,
ufunc: _UFT_contra,
method: L[_MethodCommon, "inner"],
*args: Any,
**kwargs: Any,
) -> _T_co: ...
def __array_ufunc__(
self,
ufunc: _UFT_contra,
method: L[_MethodCommon],
/,
*args: Any,
**kwargs: Any,
) -> _T_co: ...


_FT_contra = TypeVar("_FT_contra", bound=_AnyFunc, default=_AnyFunc, contravariant=True)
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ dependencies = ["typing-extensions>=4.10; python_version<'3.13'"]
[dependency-groups]
extra = ["optype[numpy]"]
lint = [
"ruff>=0.11.12",
"ruff>=0.11.13",
"sp-repo-review[cli]>=2025.5.2",
]
type = [
Expand All @@ -46,7 +46,7 @@ type = [
]
test = [
"beartype>=0.21.0",
"pytest>=8.3.5",
"pytest>=8.4.0",
]
dev = [
{include-group = "extra"},
Expand Down Expand Up @@ -129,7 +129,7 @@ reportUnusedVariable = false # dupe of F841
NP20 = true
NP21 = true
NP22 = true
NP23 = false
NP23 = true


[tool.pytest.ini_options]
Expand Down
Loading