
optype.numpy.ScalarTypeOf #481

@jorenham


That is, a protocol along these lines:

from typing import Protocol, final, type_check_only

import numpy as np
import optype.numpy as onp


@final
@type_check_only
class ScalarTypeOf[ArrayT: np.ndarray](Protocol):
    """
    Only accepts a type `T` iff `T.view() ~ ScalarT` and
    `ArrayT.dtype.type ~ type[ScalarT]`.
    This relies on the fact that `T.view() -> T` for both `T: np.generic` and
    `T: np.ndarray`.
    """
    def view[ScalarT: np.generic](self: ScalarTypeOf[onp.ArrayND[ScalarT]], /) -> ScalarT: ...

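A quick runtime sanity check of the two facts the protocol relies on. This is only a sketch assuming NumPy is installed; the protocol itself only matters to a static type-checker, so it isn't used here. It checks that `.view()` without arguments returns the same type for an `np.ndarray` subclass and for an `np.generic` scalar, and that `dtype.type` is the scalar type.

```python
import numpy as np

# An ndarray subclass: .view() keeps the subclass (MaskedArray here)
arr = np.ma.masked_array([1.0, 2.0], mask=[False, True])
assert type(arr.view()) is np.ma.MaskedArray

# A NumPy scalar: .view() returns a scalar of the same type
scalar = np.float64(1.5)
assert type(scalar.view()) is np.float64

# The array's dtype.type is exactly that scalar type
assert arr.dtype.type is np.float64
assert arr.dtype.type is type(scalar)
```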
The `ScalarTypeOf` protocol allows you to extract the scalar type of an array without losing the specific subtype of the array, e.g.:

from typing import overload

import numpy as np
import optype.numpy as onp


# `ScalarTypeOf` is the protocol defined above.
@overload
def copy[ArrayT: np.ndarray](
    a: ArrayT,
    /,
    dtype: onp.ToDType[ScalarTypeOf[ArrayT]] | None = None,
) -> ArrayT: ...
@overload
def copy[ShapeT: tuple[int, ...], ScalarT: np.generic](
    a: onp.Array[ShapeT],
    /,
    dtype: onp.ToDType[ScalarT],
) -> onp.Array[ShapeT, ScalarT]: ...
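At runtime, the two overloads boil down to roughly the following. This is a hypothetical sketch using only NumPy (no `optype`, no `ScalarTypeOf`): comparing `np.dtype(dtype) == a.dtype` is a slightly stricter stand-in for the `ScalarTypeOf[ArrayT]` check, and `np.array(..., subok=False)` is just one way to fall back to a plain `np.ndarray`.

```python
from __future__ import annotations

import numpy as np
import numpy.typing as npt


def copy(a: np.ndarray, /, dtype: npt.DTypeLike | None = None) -> np.ndarray:
    """Hypothetical runtime counterpart of the two `copy` overloads."""
    # Like the first overload: no dtype, or the array's own dtype, so
    # `a.copy()` is used, which keeps the subclass (e.g. `np.ma.MaskedArray`).
    if dtype is None or np.dtype(dtype) == a.dtype:
        return a.copy()
    # Like the second overload: a different dtype, so only a base
    # `np.ndarray` with the same shape is returned (`subok=False`).
    return np.array(a, dtype=dtype, subok=False)
```

So `copy(masked)` and `copy(masked, dtype=masked.dtype)` both keep the `MaskedArray` type, while a different dtype yields a plain `ndarray`.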

Now, even if you call `copy` with, e.g., a masked array and a dtype matching that of the array, it will still return a masked array. Note that general subtype-preserving behavior (e.g. returning a masked array with a *different* scalar type) can't be expressed in the current Python type system, as that would require support for higher-kinded types.
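For comparison, NumPy's own `astype` keeps the subclass even when the scalar type changes. Annotating that precisely would need something like `ArrayT[NewScalarT]` (hypothetical, invalid notation), i.e. applying type arguments to a type variable, which is exactly the higher-kinded case. A runtime check, assuming only NumPy:

```python
import numpy as np

m = np.ma.masked_array([1, 2, 3], mask=[False, True, False])

# The scalar type changes (integer -> float32) ...
converted = m.astype(np.float32)
assert converted.dtype == np.float32

# ... but the subclass is kept: the result is still a MaskedArray.
assert type(converted) is np.ma.MaskedArray
```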


The `ScalarTypeOf` approach could also be used to create similar protocols for the dtype and shape, e.g. `DTypeOf[ArrayT]` and `ShapeTypeOf[ArrayT]` (or `ShapeOf[ArrayT]`), respectively.
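At runtime, the pieces that `ScalarTypeOf`, `DTypeOf` and `ShapeOf` would recover statically are simply `a.dtype.type`, `a.dtype` and `a.shape`, regardless of which `ndarray` subclass `a` is. A small sketch that collects them; `ArrayTypeParams` and `type_params` are made-up names for illustration only, not part of `optype`:

```python
from __future__ import annotations

from typing import Any, NamedTuple

import numpy as np


class ArrayTypeParams(NamedTuple):
    """Hypothetical container for the runtime counterparts of the protocols."""

    array_type: type  # the concrete ndarray subclass, e.g. MaskedArray
    shape: tuple[int, ...]  # what ShapeOf[ArrayT] would describe
    dtype: np.dtype[Any]  # what DTypeOf[ArrayT] would describe
    scalar_type: type  # what ScalarTypeOf[ArrayT] would describe


def type_params(a: np.ndarray) -> ArrayTypeParams:
    """Hypothetical helper: read the "type parameters" off an array instance."""
    return ArrayTypeParams(type(a), a.shape, a.dtype, a.dtype.type)
```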
