i.e.
```python
# Note: `type_check_only` is not available at runtime, so this definition
# belongs in a stub file (`.pyi`) or inside an `if TYPE_CHECKING:` block.
from typing import Protocol, final, type_check_only

import numpy as np
import optype.numpy as onp


@final
@type_check_only
class ScalarTypeOf[ArrayT: np.ndarray](Protocol):
    """
    This will only accept things of type `T` iff `T.view() ~ ScalarT` and
    `ArrayT.dtype.type ~ type[ScalarT]`.

    It relies on the fact that `T.view() -> T` for `T: np.generic` and `T: np.ndarray`.
    """

    def view[ScalarT: np.generic](self: ScalarTypeOf[onp.ArrayND[ScalarT]], /) -> ScalarT: ...
```

This allows you to extract the scalar type of an array without losing the specific subtype of the array, e.g.:
```python
from typing import overload

import numpy as np
import optype.numpy as onp


# The dtype is omitted or matches the array's own scalar type:
# the exact array subtype is preserved.
@overload
def copy[ArrayT: np.ndarray](
    a: ArrayT,
    /,
    dtype: onp.ToDType[ScalarTypeOf[ArrayT]] | None = None,
) -> ArrayT: ...
# Any other dtype: only the shape type is preserved.
@overload
def copy[ShapeT: tuple[int, ...], ScalarT: np.generic](
    a: onp.Array[ShapeT],
    /,
    dtype: onp.ToDType[ScalarT],
) -> onp.Array[ShapeT, ScalarT]: ...
```

Now, even if you call `copy` with, e.g., a masked array and a dtype matching that of the array, it will still return a masked array. Note that general subtype-preserving behavior can't be expressed in the current Python type system, as that would require support for higher-kinded types.
This idea could also be used to create similar protocols for the dtype and the shape, i.e. `DTypeOf[ArrayT]` and `ShapeTypeOf[ArrayT]` (or `ShapeOf[ArrayT]`), respectively.
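For reference, the runtime behavior that the `ScalarTypeOf` trick and the first `copy` overload describe can be checked with plain NumPy. This is only an illustrative sketch: `copy_like` is a hypothetical helper name, not part of NumPy or optype, and it is assumed to stand in for whatever function the `copy` overloads would annotate.

```python
import numpy as np

# `view()` without arguments returns an object of the receiver's own type,
# for scalars (`np.generic`) as well as arrays (`np.ndarray` and subclasses).
scalar = np.float64(1.5)
array = np.arange(3, dtype=np.float64)
masked = np.ma.masked_array([1.0, 2.0, 3.0], mask=[False, True, False])

scalar_view = scalar.view()
array_view = array.view()
masked_view = masked.view()

# For an array, `dtype.type` is the scalar type that `ScalarTypeOf` exposes.
array_scalar_type = array.dtype.type


def copy_like(a: np.ndarray, dtype=None) -> np.ndarray:
    """Hypothetical runtime counterpart of the `copy` overloads."""
    # Both `ndarray.copy()` and `ndarray.astype()` keep the array subclass.
    return a.copy() if dtype is None else a.astype(dtype)


same_dtype_copy = copy_like(masked, np.float64)
other_dtype_copy = copy_like(masked, np.float32)
```

At runtime the masked array survives in both calls. The overloads only promise the `ArrayT` return type statically for the first case, which is the limitation the note on higher-kinded types refers to.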