Skip to content

Commit f84366d

Browse files
committed
added opencl functionality and tests within wrapper
1 parent 7f97198 commit f84366d

File tree

2 files changed

+126
-1
lines changed

2 files changed

+126
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -1 +1,100 @@
1-
# TODO
1+
import ctypes
2+
from enum import Enum
3+
4+
from arrayfire_wrapper.lib._utility import call_from_clib
5+
6+
7+
class DeviceType(Enum):
8+
CPU = 2
9+
GPU = 4
10+
ACC = 8
11+
UNKNOWN = -1
12+
13+
14+
class PlatformType(Enum):
15+
AMD = 0
16+
APPLE = 1
17+
INTEL = 2
18+
NVIDIA = 3
19+
BEIGNET = 4
20+
POCL = 5
21+
UNKNOWN = -1
22+
23+
24+
def get_context(retain: bool = False) -> int:
25+
"""
26+
source: https://arrayfire.org/docs/group__opencl__mat.htm#gad42de383f405b3e38d6eb669c0cbe2e3
27+
"""
28+
out = ctypes.c_void_p()
29+
call_from_clib(get_context.__name__, ctypes.pointer(out), retain, clib_prefix="afcl")
30+
return out.value # type: ignore[return-value]
31+
32+
33+
def get_queue(retain: bool = False) -> int:
34+
"""
35+
source: https://arrayfire.org/docs/group__opencl__mat.htm#gab1701ef4f2b68429eb31c1e21c88d0bc
36+
"""
37+
out = ctypes.c_void_p()
38+
call_from_clib(get_queue.__name__, ctypes.pointer(out), retain, clib_prefix="afcl")
39+
return out.value # type: ignore[return-value]
40+
41+
42+
def get_device_id() -> int:
43+
"""
44+
source: https://arrayfire.org/docs/group__opencl__mat.htm#gaf7258055284e65a8647a49c3f3b9feee
45+
"""
46+
out = ctypes.c_void_p()
47+
call_from_clib(get_device_id.__name__, ctypes.pointer(out), clib_prefix="afcl")
48+
return out.value # type: ignore[return-value]
49+
50+
51+
def set_device_id(idx: int) -> None:
52+
"""
53+
source: https://arrayfire.org/docs/group__opencl__mat.htm#ga600361a20ceac2a65590b67fc0366314
54+
"""
55+
call_from_clib(set_device_id.__name__, ctypes.c_int64(idx), clib_prefix="afcl")
56+
return None
57+
58+
59+
def add_device_context(dev: int, ctx: int, que: int) -> None:
60+
"""
61+
source: https://arrayfire.org/docs/group__opencl__mat.htm#ga49f596a4041fb757f1f5a75999cf8858
62+
"""
63+
call_from_clib(
64+
add_device_context.__name__, ctypes.c_int64(dev), ctypes.c_int64(ctx), ctypes.c_int64(que), clib_prefix="afcl"
65+
)
66+
return None
67+
68+
69+
def set_device_context(dev: int, ctx: int) -> None:
70+
"""
71+
source: https://arrayfire.org/docs/group__opencl__mat.htm#ga975661f2b06dddb125c5d1757160b02c
72+
"""
73+
call_from_clib(set_device_context.__name__, ctypes.c_int64(dev), ctypes.c_int64(ctx), clib_prefix="afcl")
74+
return None
75+
76+
77+
def delete_device_context(dev: int, ctx: int) -> None:
78+
"""
79+
source: https://arrayfire.org/docs/group__opencl__mat.htm#ga1a56dcf05099d6ac0a3b7701f7cb23f8
80+
"""
81+
call_from_clib(delete_device_context.__name__, ctypes.c_int64(dev), ctypes.c_int64(ctx), clib_prefix="afcl")
82+
return None
83+
84+
85+
def get_device_type() -> DeviceType:
86+
"""
87+
source: https://arrayfire.org/docs/group__opencl__mat.htm#ga5e360e0fe0eb55d0046191bc3fd6f81d
88+
"""
89+
res = ctypes.c_void_p()
90+
call_from_clib(get_device_type.__name__, ctypes.pointer(res), clib_prefix="afcl")
91+
return DeviceType(res.value)
92+
93+
94+
def get_platform() -> PlatformType:
95+
"""
96+
source: https://arrayfire.org/docs/group__opencl__mat.htm#ga5e360e0fe0eb55d0046191bc3fd6f81d&gsc.tab=0
97+
"""
98+
res = ctypes.c_void_p()
99+
call_from_clib(get_platform.__name__, ctypes.pointer(res), clib_prefix="afcl")
100+
return PlatformType(res.value)

tests/test_opencl.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import arrayfire_wrapper.lib.interface_functions.opencl as cl
2+
3+
4+
def test_get_context_type():
5+
assert isinstance(cl.get_context(), int)
6+
7+
8+
def test_get_queue_type():
9+
assert isinstance(cl.get_queue(), int)
10+
11+
12+
def test_get_device_id():
13+
assert isinstance(cl.get_device_id(), int)
14+
15+
16+
def test_set_device_id():
17+
cl.set_device_id(0)
18+
assert cl.get_device_id() == 0
19+
20+
21+
def test_get_device_type():
22+
assert cl.get_device_type() == cl.DeviceType.GPU # change according to device
23+
24+
25+
def test_get_platform():
26+
assert cl.get_platform() == cl.PlatformType.INTEL # change according to platform

0 commit comments

Comments
 (0)