diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py
index 21d157f022..e949627f59 100644
--- a/src/gt4py/_core/definitions.py
+++ b/src/gt4py/_core/definitions.py
@@ -378,6 +378,7 @@ class DeviceType(enum.IntEnum):
     https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h
     """
 
+    INVALID_DEVICE = -1
     CPU = 1
     CUDA = 2
     # CPU_PINNED = 3  # noqa: ERA001
@@ -388,11 +389,18 @@ class DeviceType(enum.IntEnum):
     ROCM = 10
     # CUDA_MANAGED = 13  # noqa: ERA001
     # ONE_API = 14  # noqa: ERA001
+    GPU = (  # TODO: revisit this definition; when CuPy is not installed it resolves to ROCM
+        CUDA
+        if cp is not None
+        and not (cp.cuda.runtime.is_hip if hasattr(cp.cuda.runtime, "is_hip") else False)
+        else ROCM
+    )
 
 
 CPUDeviceTyping: TypeAlias = Literal[DeviceType.CPU]
 CUDADeviceTyping: TypeAlias = Literal[DeviceType.CUDA]
 ROCMDeviceTyping: TypeAlias = Literal[DeviceType.ROCM]
+GPUDeviceTyping: TypeAlias = Literal[DeviceType.GPU]
 
 
 DeviceTypeT = TypeVar(
@@ -403,9 +411,7 @@ class DeviceType(enum.IntEnum):
 )
 
 
-CUPY_DEVICE_TYPE = (
-    None if not cp else (DeviceType.ROCM if cp.cuda.runtime.is_hip else DeviceType.CUDA)
-)
+CUPY_DEVICE_TYPE = DeviceType.GPU
 
 
 @dataclasses.dataclass(frozen=True)
diff --git a/src/gt4py/next/allocators.py b/src/gt4py/next/allocators.py
index 097b57371b..bb133e33df 100644
--- a/src/gt4py/next/allocators.py
+++ b/src/gt4py/next/allocators.py
@@ -89,7 +89,7 @@ def is_field_allocation_tool(obj: Any) -> TypeGuard[FieldBufferAllocationUtil]:
 
 def is_field_allocation_tool_for(
     obj: Any, device: core_defs.DeviceTypeT
-) -> TypeGuard[FieldBufferAllocationUtil]:
+) -> TypeGuard[FieldBufferAllocationUtil[core_defs.DeviceTypeT]]:
     return is_field_allocator_for(obj, device) or is_field_allocator_factory_for(obj, device)
 
 