Skip to content

Commit 5bd4c99

Browse files
authored
Merge branch 'master' into logical_testing
2 parents 7e46cca + b7e42bd commit 5bd4c99

File tree

10 files changed

+189
-34
lines changed

10 files changed

+189
-34
lines changed

arrayfire_wrapper/_backend.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__all__ = ["BackendType"]
1+
__all__ = ["BackendType", "get_backend", "set_backend"]
22

33
import ctypes
44
import enum

arrayfire_wrapper/lib/create_and_modify_array/create_array/pad.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ def pad(arr: AFArray, begin_shape: tuple[int, ...], end_shape: tuple[int, ...],
1717
ctypes.pointer(out),
1818
arr,
1919
4,
20-
begin_c_shape.c_array,
20+
ctypes.pointer(begin_c_shape.c_array),
2121
4,
22-
end_c_shape.c_array,
22+
ctypes.pointer(end_c_shape.c_array),
2323
border_type.value,
2424
)
2525
return out

arrayfire_wrapper/lib/create_and_modify_array/manage_device.py

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ def alloc_host(num_bytes: int, /) -> int:
1111
1212
Allocate a buffer on the host with specified number of bytes.
1313
"""
14+
# TODO
15+
# Avoid using AFArray and use ctypes.c_void_p to avoid misunderstanding 'coz its not actually an array
1416
out = AFArray.create_null_pointer()
1517
call_from_clib(alloc_host.__name__, ctypes.pointer(out), CDimT(num_bytes))
1618
return out.value # type: ignore[return-value]

arrayfire_wrapper/lib/create_and_modify_array/move_and_reorder.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -78,25 +78,31 @@ def replace_scalar(lhs: AFArray, cond_arr: AFArray, rhs: int | float, /) -> None
7878
call_from_clib(replace.__name__, lhs, cond_arr, ctypes.c_double(rhs))
7979

8080

81-
def select(lhs: AFArray, cond_arr: AFArray, rhs: AFArray, /) -> None:
81+
def select(lhs: AFArray, cond_arr: AFArray, rhs: AFArray, /) -> AFArray:
8282
"""
8383
source: https://arrayfire.org/docs/group__data__func__select.htm#gac4af16e31ddd5ddcf09b670f676fd093
8484
"""
85-
call_from_clib(select.__name__, cond_arr, lhs, rhs)
85+
out = AFArray.create_null_pointer()
86+
call_from_clib(select.__name__, ctypes.pointer(out), cond_arr, lhs, rhs)
87+
return out
8688

8789

88-
def select_scalar_l(lhs: int | float, cond_arr: AFArray, rhs: AFArray, /) -> None:
90+
def select_scalar_l(lhs: int | float, cond_arr: AFArray, rhs: AFArray, /) -> AFArray:
8991
"""
9092
source: https://arrayfire.org/docs/group__data__func__select.htm#gac4af16e31ddd5ddcf09b670f676fd093
9193
"""
92-
call_from_clib(select_scalar_l.__name__, cond_arr, ctypes.c_double(lhs), rhs)
94+
out = AFArray.create_null_pointer()
95+
call_from_clib(select_scalar_l.__name__, ctypes.pointer(out), cond_arr, ctypes.c_double(lhs), rhs)
96+
return out
9397

9498

95-
def select_scalar_r(lhs: AFArray, cond_arr: AFArray, rhs: int | float, /) -> None:
99+
def select_scalar_r(lhs: AFArray, cond_arr: AFArray, rhs: int | float, /) -> AFArray:
96100
"""
97101
source: https://arrayfire.org/docs/group__data__func__select.htm#gac4af16e31ddd5ddcf09b670f676fd093
98102
"""
99-
call_from_clib(select_scalar_l.__name__, cond_arr, lhs, ctypes.c_double(rhs))
103+
out = AFArray.create_null_pointer()
104+
call_from_clib(select_scalar_l.__name__, ctypes.pointer(out), cond_arr, lhs, ctypes.c_double(rhs))
105+
return out
100106

101107

102108
def shift(arr: AFArray, /, d0: int, d1: int = 0, d2: int = 0, d3: int = 0) -> AFArray:

arrayfire_wrapper/lib/vector_algorithms/reduction_operations.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,7 @@ def max_ragged(arr: AFArray, ragged_len: AFArray, dim: int, /) -> tuple[AFArray,
144144
"""
145145
out_values = AFArray.create_null_pointer()
146146
out_idx = AFArray.create_null_pointer()
147-
call_from_clib(
148-
max_ragged.__name__, ctypes.pointer(out_values), ctypes.pointer(out_idx), arr, ragged_len, ctypes.c_int(dim)
149-
)
147+
call_from_clib(max_ragged.__name__, ctypes.pointer(out_values), ctypes.pointer(out_idx), arr, ragged_len, dim)
150148
return (out_values, out_idx)
151149

152150

@@ -156,9 +154,7 @@ def max_by_key(keys: AFArray, values: AFArray, dim: int, /) -> tuple[AFArray, AF
156154
"""
157155
out_keys = AFArray.create_null_pointer()
158156
out_values = AFArray.create_null_pointer()
159-
call_from_clib(
160-
max_by_key.__name__, ctypes.pointer(out_keys), ctypes.pointer(out_values), keys, values, ctypes.c_int(dim)
161-
)
157+
call_from_clib(max_by_key.__name__, ctypes.pointer(out_keys), ctypes.pointer(out_values), keys, values, dim)
162158
return (out_keys, out_values)
163159

164160

arrayfire_wrapper/lib/vector_algorithms/set_operations.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,5 @@ def set_unique(arr: AFArray, is_sorted: bool, /) -> AFArray:
2727
source: https://arrayfire.org/docs/group__set__func__unique.htm#ga6afa1de48cbbc4b2df530c2530087943
2828
"""
2929
out = AFArray.create_null_pointer()
30-
call_from_clib(set_intersect.__name__, ctypes.pointer(out), ctypes.c_bool(is_sorted))
30+
call_from_clib(set_unique.__name__, ctypes.pointer(out), arr, ctypes.c_bool(is_sorted))
3131
return out

arrayfire_wrapper/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22

33
_MAJOR = "0"
4-
_MINOR = "6"
4+
_MINOR = "7"
55
# On main and in a nightly release the patch should be one ahead of the last
66
# released build.
77
_PATCH = "0"

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ build-backend = "scikit_build_core.build"
2929

3030
[project]
3131
name = "arrayfire-binary-python-wrapper"
32-
version = "0.6.0+AF3.9.0"
32+
version = "0.7.0+AF3.9.0"
3333
requires-python = ">=3.10"
3434
authors = [
3535
{ name = "ArrayFire", email = "[email protected]"},

tests/test_arithmetic.py

+152
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import random
2+
3+
import pytest
4+
5+
import arrayfire_wrapper.dtypes as dtype
6+
import arrayfire_wrapper.lib as wrapper
7+
from tests.utility_functions import check_type_supported, get_all_types
8+
9+
10+
@pytest.mark.parametrize(
11+
"shape",
12+
[
13+
(),
14+
(random.randint(1, 10),),
15+
(random.randint(1, 10), random.randint(1, 10)),
16+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
17+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
18+
],
19+
)
20+
def test_add_shapes(shape: tuple) -> None:
21+
"""Test addition operation between two arrays of the same shape"""
22+
lhs = wrapper.randu(shape, dtype.f16)
23+
rhs = wrapper.randu(shape, dtype.f16)
24+
25+
result = wrapper.add(lhs, rhs)
26+
27+
assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203, W291
28+
29+
30+
def test_add_different_shapes() -> None:
31+
"""Test if addition handles arrays of different shapes"""
32+
with pytest.raises(RuntimeError):
33+
lhs_shape = (2, 3)
34+
rhs_shape = (3, 2)
35+
dtypes = dtype.f16
36+
lhs = wrapper.randu(lhs_shape, dtypes)
37+
rhs = wrapper.randu(rhs_shape, dtypes)
38+
39+
wrapper.add(lhs, rhs)
40+
41+
42+
@pytest.mark.parametrize("dtype_name", get_all_types())
43+
def test_add_supported_dtypes(dtype_name: dtype.Dtype) -> None:
44+
"""Test addition operation across all supported data types."""
45+
check_type_supported(dtype_name)
46+
shape = (5, 5) # Using a common shape for simplicity
47+
lhs = wrapper.randu(shape, dtype_name)
48+
rhs = wrapper.randu(shape, dtype_name)
49+
result = wrapper.add(lhs, rhs)
50+
assert dtype.c_api_value_to_dtype(wrapper.get_type(result)) == dtype_name, f"Failed for dtype: {dtype_name}"
51+
52+
53+
@pytest.mark.parametrize(
54+
"invdtypes",
55+
[
56+
dtype.c64,
57+
dtype.f64,
58+
],
59+
)
60+
def test_add_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
61+
"""Test addition operation across all supported data types."""
62+
with pytest.raises(RuntimeError):
63+
shape = (5, 5)
64+
lhs = wrapper.randu(shape, invdtypes)
65+
rhs = wrapper.randu(shape, invdtypes)
66+
result = wrapper.add(lhs, rhs)
67+
assert dtype.c_api_value_to_dtype(wrapper.get_type(result)) == invdtypes, f"Didn't Fail for Dtype: {invdtypes}"
68+
69+
70+
def test_add_zero_sized_arrays() -> None:
71+
"""Test addition with arrays where at least one array has zero size."""
72+
with pytest.raises(RuntimeError):
73+
zero_shape = (0, 5)
74+
normal_shape = (5, 5)
75+
zero_array = wrapper.randu(zero_shape, dtype.f32)
76+
normal_array = wrapper.randu(normal_shape, dtype.f32)
77+
78+
# Test addition when lhs is zero-sized
79+
result_lhs_zero = wrapper.add(zero_array, normal_array)
80+
assert wrapper.get_dims(result_lhs_zero) == zero_shape
81+
82+
83+
@pytest.mark.parametrize(
84+
"shape",
85+
[
86+
(),
87+
(random.randint(1, 10),),
88+
(random.randint(1, 10), random.randint(1, 10)),
89+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
90+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
91+
],
92+
)
93+
def test_subtract_shapes(shape: tuple) -> None:
94+
"""Test subtraction operation between two arrays of the same shape"""
95+
lhs = wrapper.randu(shape, dtype.f16)
96+
rhs = wrapper.randu(shape, dtype.f16)
97+
98+
result = wrapper.sub(lhs, rhs)
99+
100+
assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203, W291
101+
102+
103+
def test_subtract_different_shapes() -> None:
104+
"""Test if subtraction handles arrays of different shapes"""
105+
with pytest.raises(RuntimeError):
106+
lhs_shape = (2, 3)
107+
rhs_shape = (3, 2)
108+
dtypes = dtype.f16
109+
lhs = wrapper.randu(lhs_shape, dtypes)
110+
rhs = wrapper.randu(rhs_shape, dtypes)
111+
112+
wrapper.sub(lhs, rhs)
113+
114+
115+
@pytest.mark.parametrize("dtype_name", get_all_types())
116+
def test_subtract_supported_dtypes(dtype_name: dtype.Dtype) -> None:
117+
"""Test subtraction operation across all supported data types."""
118+
check_type_supported(dtype_name)
119+
shape = (5, 5)
120+
lhs = wrapper.randu(shape, dtype_name)
121+
rhs = wrapper.randu(shape, dtype_name)
122+
result = wrapper.sub(lhs, rhs)
123+
assert dtype.c_api_value_to_dtype(wrapper.get_type(result)) == dtype_name, f"Failed for dtype: {dtype_name}"
124+
125+
126+
@pytest.mark.parametrize(
127+
"invdtypes",
128+
[
129+
dtype.c64,
130+
dtype.f64,
131+
],
132+
)
133+
def test_subtract_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
134+
"""Test subtraction operation for unsupported data types."""
135+
with pytest.raises(RuntimeError):
136+
shape = (5, 5)
137+
lhs = wrapper.randu(shape, invdtypes)
138+
rhs = wrapper.randu(shape, invdtypes)
139+
result = wrapper.sub(lhs, rhs)
140+
assert result == invdtypes, f"Didn't Fail for Dtype: {invdtypes}"
141+
142+
143+
def test_subtract_zero_sized_arrays() -> None:
144+
"""Test subtraction with arrays where at least one array has zero size."""
145+
with pytest.raises(RuntimeError):
146+
zero_shape = (0, 5)
147+
normal_shape = (5, 5)
148+
zero_array = wrapper.randu(zero_shape, dtype.f32)
149+
normal_array = wrapper.randu(normal_shape, dtype.f32)
150+
151+
result_lhs_zero = wrapper.sub(zero_array, normal_array)
152+
assert wrapper.get_dims(result_lhs_zero) == zero_shape

tests/test_trig.py

+15-16
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44

55
import arrayfire_wrapper.dtypes as dtype
66
import arrayfire_wrapper.lib as wrapper
7-
8-
from . import utility_functions as util
7+
from tests.utility_functions import check_type_supported, get_all_types, get_float_types
98

109

1110
@pytest.mark.parametrize(
@@ -18,10 +17,10 @@
1817
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
1918
],
2019
)
21-
@pytest.mark.parametrize("dtype_name", util.get_all_types())
20+
@pytest.mark.parametrize("dtype_name", get_all_types())
2221
def test_asin_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
2322
"""Test inverse sine operation across all supported data types."""
24-
util.check_type_supported(dtype_name)
23+
check_type_supported(dtype_name)
2524
values = wrapper.randu(shape, dtype_name)
2625
result = wrapper.asin(values)
2726
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -37,10 +36,10 @@ def test_asin_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
3736
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
3837
],
3938
)
40-
@pytest.mark.parametrize("dtype_name", util.get_all_types())
39+
@pytest.mark.parametrize("dtype_name", get_all_types())
4140
def test_acos_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
4241
"""Test inverse cosine operation across all supported data types."""
43-
util.check_type_supported(dtype_name)
42+
check_type_supported(dtype_name)
4443
values = wrapper.randu(shape, dtype_name)
4544
result = wrapper.acos(values)
4645
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -56,10 +55,10 @@ def test_acos_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
5655
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
5756
],
5857
)
59-
@pytest.mark.parametrize("dtype_name", util.get_all_types())
58+
@pytest.mark.parametrize("dtype_name", get_all_types())
6059
def test_atan_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
6160
"""Test inverse tan operation across all supported data types."""
62-
util.check_type_supported(dtype_name)
61+
check_type_supported(dtype_name)
6362
values = wrapper.randu(shape, dtype_name)
6463
result = wrapper.atan(values)
6564
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -75,10 +74,10 @@ def test_atan_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
7574
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
7675
],
7776
)
78-
@pytest.mark.parametrize("dtype_name", util.get_float_types())
77+
@pytest.mark.parametrize("dtype_name", get_float_types())
7978
def test_atan2_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
8079
"""Test inverse tan operation across all supported data types."""
81-
util.check_type_supported(dtype_name)
80+
check_type_supported(dtype_name)
8281
if dtype_name == dtype.f16:
8382
pytest.skip()
8483
lhs = wrapper.randu(shape, dtype_name)
@@ -110,10 +109,10 @@ def test_atan2_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
110109
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
111110
],
112111
)
113-
@pytest.mark.parametrize("dtype_name", util.get_all_types())
112+
@pytest.mark.parametrize("dtype_name", get_all_types())
114113
def test_cos_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
115114
"""Test cosine operation across all supported data types."""
116-
util.check_type_supported(dtype_name)
115+
check_type_supported(dtype_name)
117116
values = wrapper.randu(shape, dtype_name)
118117
result = wrapper.cos(values)
119118
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -129,10 +128,10 @@ def test_cos_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
129128
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
130129
],
131130
)
132-
@pytest.mark.parametrize("dtype_name", util.get_all_types())
131+
@pytest.mark.parametrize("dtype_name", get_all_types())
133132
def test_sin_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
134133
"""Test sin operation across all supported data types."""
135-
util.check_type_supported(dtype_name)
134+
check_type_supported(dtype_name)
136135
values = wrapper.randu(shape, dtype_name)
137136
result = wrapper.sin(values)
138137
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -148,10 +147,10 @@ def test_sin_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
148147
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
149148
],
150149
)
151-
@pytest.mark.parametrize("dtype_name", util.get_all_types())
150+
@pytest.mark.parametrize("dtype_name", get_all_types())
152151
def test_tan_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
153152
"""Test tan operation across all supported data types."""
154-
util.check_type_supported(dtype_name)
153+
check_type_supported(dtype_name)
155154
values = wrapper.randu(shape, dtype_name)
156155
result = wrapper.tan(values)
157156
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa

0 commit comments

Comments
 (0)