Skip to content

Commit 7f97198

Browse files
committed
"update local master branch"
Merge branch 'master' of https://github.com/arrayfire/arrayfire-binary-python-wrapper
2 parents afb263e + 983500e commit 7f97198

12 files changed

+2733
-57
lines changed

arrayfire_wrapper/_backend.py

+8-44
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
from pathlib import Path
1212
from typing import Iterator
1313

14-
from arrayfire_wrapper.defines import AFArray
15-
1614
from .version import ARRAYFIRE_VER_MAJOR
1715

1816
VERBOSE_LOADS = os.environ.get("AF_VERBOSE_LOADS", "") == "1"
@@ -37,7 +35,7 @@ def is_cygwin(cls, name: str) -> bool:
3735
class _BackendPathConfig:
3836
lib_prefix: str
3937
lib_postfix: str
40-
af_path: Path
38+
af_path: Path | None
4139
af_is_user_path: bool
4240
cuda_found: bool
4341

@@ -175,7 +173,7 @@ def __iter__(self) -> Iterator:
175173

176174

177175
class Backend:
178-
_backend_type: BackendType
176+
_backend_type: BackendType | None
179177
_clibs: dict[BackendType, ctypes.CDLL]
180178

181179
def __init__(self) -> None:
@@ -297,51 +295,17 @@ def _find_nvrtc_builtins_lib_name(self, search_path: Path) -> str | None:
297295
return f.name
298296
return None
299297

300-
# unified backend functions
301-
def get_active_backend(self) -> str:
302-
if self._backend_type == BackendType.unified:
303-
from arrayfire_wrapper.lib.unified_api_functions import get_active_backend as unified_get_active_backend
304-
305-
return unified_get_active_backend()
306-
raise RuntimeError("Using unified function on non-unified backend")
307-
308-
def get_available_backends(self) -> list[int]:
309-
if self._backend_type == BackendType.unified:
310-
from arrayfire_wrapper.lib.unified_api_functions import (
311-
get_available_backends as unified_get_available_backends,
312-
)
313-
314-
return unified_get_available_backends()
315-
raise RuntimeError("Using unified function on non-unified backend")
316-
317-
def get_backend_count(self) -> int:
318-
if self._backend_type == BackendType.unified:
319-
from arrayfire_wrapper.lib.unified_api_functions import get_backend_count as unified_get_backend_count
320-
321-
return unified_get_backend_count()
322-
raise RuntimeError("Using unified function on non-unified backend")
323-
324-
def get_backend_id(self, arr: AFArray, /) -> int:
325-
if self._backend_type == BackendType.unified:
326-
from arrayfire_wrapper.lib.unified_api_functions import get_backend_id as unified_get_backend_id
327-
328-
return unified_get_backend_id(arr)
329-
raise RuntimeError("Using unified function on non-unified backend")
330-
331-
def get_device_id(self, arr: AFArray, /) -> int:
332-
if self._backend_type == BackendType.unified:
333-
from arrayfire_wrapper.lib.unified_api_functions import get_device_id as unified_get_device_id
334-
335-
return unified_get_device_id(arr)
336-
raise RuntimeError("Using unified function on non-unified backend")
337-
338298
@property
339299
def backend_type(self) -> BackendType:
340-
return self._backend_type
300+
if self._backend_type:
301+
return self._backend_type
302+
raise RuntimeError("No valid _backend_type")
341303

342304
@property
343305
def clib(self) -> ctypes.CDLL:
344-
return self._clibs[self._backend_type]
306+
if self._backend_type:
307+
return self._clibs[self._backend_type]
308+
raise RuntimeError("No valid _backend_type")
345309

346310

347311
# Initialize the backend

arrayfire_wrapper/lib/mathematical_functions/numeric_functions.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
import ctypes
22

3+
import arrayfire_wrapper.dtypes as dtype
4+
import arrayfire_wrapper.lib as wrapper
35
from arrayfire_wrapper.defines import AFArray
4-
from arrayfire_wrapper.dtypes import float32
56
from arrayfire_wrapper.lib._utility import binary_op, call_from_clib, unary_op
67
from arrayfire_wrapper.lib.create_and_modify_array.create_array import create_constant_array
78
from arrayfire_wrapper.lib.mathematical_functions.arithmetic_operations import sub
89

910

10-
import arrayfire_wrapper.dtypes as dtype
11-
import arrayfire_wrapper.lib as wrapper
12-
13-
1411
def abs_(arr: AFArray, /) -> AFArray:
1512
"""
1613
source: https://arrayfire.org/docs/group__arith__func__abs.htm#ga7e8b3c848e6cda3d1f3b0c8b2b4c3f8f

scripts/build_package_without_binaries.sh

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
#!/bin/bash
22

3-
# Run the Python script and capture the output and error
3+
# Run the Python script and capture the output or error
44
output=$(python -m build 2>&1)
55

6-
# Define the expected error message
7-
expected_error="Could not load any ArrayFire libraries."
6+
# Define the expected output message
7+
expected_output="Successfully built"
88

9-
# Check if the output contains the expected error message
10-
if echo "$output" | grep -q "$expected_error"; then
11-
echo "Expected error received."
12-
exit 0 # Exit with success as the error is expected
9+
# Check if the output contains the expected output message
10+
if echo "$output" | grep -q "$expected_output"; then
11+
echo "Expected output received."
12+
exit 0 # Exit with success as the output is expected
1313
else
1414
echo "Unexpected output: $output"
1515
exit 1 # Exit with failure as the output was not expected

tests/test_bitshift.py

+198
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
import pytest
2+
3+
import arrayfire_wrapper.dtypes as dtype
4+
import arrayfire_wrapper.lib as wrapper
5+
from arrayfire_wrapper.lib.create_and_modify_array.helper_functions import array_to_string
6+
from tests.utility_functions import check_type_supported, get_real_types
7+
8+
9+
@pytest.mark.parametrize("dtype_name", get_real_types())
10+
def test_bitshiftl_dtypes(dtype_name: dtype.Dtype) -> None:
11+
"""Test bit shift operation across all supported data types."""
12+
check_type_supported(dtype_name)
13+
if dtype_name == dtype.f16 or dtype_name == dtype.f32:
14+
pytest.skip()
15+
shape = (5, 5)
16+
values = wrapper.randu(shape, dtype_name)
17+
bits_to_shift = wrapper.constant(1, shape, dtype_name)
18+
19+
result = wrapper.bitshiftl(values, bits_to_shift)
20+
21+
assert dtype.c_api_value_to_dtype(wrapper.get_type(result)) == dtype_name, f"Failed for dtype: {dtype_name}"
22+
23+
24+
@pytest.mark.parametrize(
25+
"invdtypes",
26+
[
27+
dtype.c32,
28+
dtype.f64,
29+
],
30+
)
31+
def test_bitshiftl_supported_dtypes(invdtypes: dtype.Dtype) -> None:
32+
"""Test bitshift operations for unsupported integer data types."""
33+
shape = (5, 5)
34+
with pytest.raises(RuntimeError):
35+
value = wrapper.randu(shape, invdtypes)
36+
bits_to_shift = wrapper.constant(1, shape, invdtypes)
37+
38+
result = wrapper.bitshiftl(value, bits_to_shift)
39+
assert dtype.c_api_value_to_dtype(wrapper.get_type(result)) == invdtypes, f"Failed for dtype: {invdtypes}"
40+
41+
42+
@pytest.mark.parametrize("input_size", [8, 10, 12])
43+
def test_bitshiftl_varying_input_size(input_size: int) -> None:
44+
"""Test bitshift left operation with varying input sizes"""
45+
shape = (input_size, input_size)
46+
value = wrapper.randu(shape, dtype.int16)
47+
shift_amount = wrapper.constant(1, shape, dtype.int16) # Fixed shift amount for simplicity
48+
49+
result = wrapper.bitshiftl(value, shift_amount)
50+
51+
assert (wrapper.get_dims(result)[0], wrapper.get_dims(result)[1]) == shape
52+
53+
54+
@pytest.mark.parametrize(
55+
"shape",
56+
[
57+
(10,),
58+
(5, 5),
59+
(2, 3, 4),
60+
],
61+
)
62+
def test_bitshiftl_varying_shapes(shape: tuple) -> None:
63+
"""Test left bit shifting with arrays of varying shapes."""
64+
values = wrapper.randu(shape, dtype.int16)
65+
bits_to_shift = wrapper.constant(1, shape, dtype.int16)
66+
67+
result = wrapper.bitshiftl(values, bits_to_shift)
68+
69+
assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa
70+
71+
72+
@pytest.mark.parametrize("shift_amount", [-1, 0, 2, 30])
73+
def test_bitshift_left_varying_shift_amount(shift_amount: int) -> None:
74+
"""Test bitshift left operation with varying shift amounts."""
75+
shape = (5, 5)
76+
value = wrapper.randu(shape, dtype.int16)
77+
shift_amount_arr = wrapper.constant(shift_amount, shape, dtype.int16)
78+
79+
result = wrapper.bitshiftl(value, shift_amount_arr)
80+
81+
assert (wrapper.get_dims(result)[0], wrapper.get_dims(result)[1]) == shape
82+
83+
84+
@pytest.mark.parametrize(
85+
"shape_a, shape_b",
86+
[
87+
((1, 5), (5, 1)), # 2D with 2D inverse
88+
((5, 5), (5, 1)), # 2D with 2D
89+
((5, 5), (1, 1)), # 2D with 2D
90+
((1, 1, 1), (5, 5, 5)), # 3D with 3D
91+
],
92+
)
93+
def test_bitshiftl_different_shapes(shape_a: tuple, shape_b: tuple) -> None:
94+
"""Test if left bit shifting handles arrays of different shapes"""
95+
with pytest.raises(RuntimeError):
96+
values = wrapper.randu(shape_a, dtype.int16)
97+
bits_to_shift = wrapper.constant(1, shape_b, dtype.int16)
98+
result = wrapper.bitshiftl(values, bits_to_shift)
99+
print(array_to_string("", result, 3, False))
100+
assert (
101+
wrapper.get_dims(result)[0 : len(shape_a)] == shape_a # noqa
102+
), f"Failed for shapes {shape_a} and {shape_b}"
103+
104+
105+
@pytest.mark.parametrize("shift_amount", [-1, 0, 2, 30])
106+
def test_bitshift_right_varying_shift_amount(shift_amount: int) -> None:
107+
"""Test bitshift right operation with varying shift amounts."""
108+
shape = (5, 5)
109+
value = wrapper.randu(shape, dtype.int16)
110+
shift_amount_arr = wrapper.constant(shift_amount, shape, dtype.int16)
111+
112+
result = wrapper.bitshiftr(value, shift_amount_arr)
113+
114+
assert (wrapper.get_dims(result)[0], wrapper.get_dims(result)[1]) == shape
115+
116+
117+
@pytest.mark.parametrize("dtype_name", get_real_types())
118+
def test_bitshiftr_dtypes(dtype_name: dtype.Dtype) -> None:
119+
"""Test bit shift operation across all supported data types."""
120+
check_type_supported(dtype_name)
121+
if dtype_name == dtype.f16 or dtype_name == dtype.f32:
122+
pytest.skip()
123+
shape = (5, 5)
124+
values = wrapper.randu(shape, dtype_name)
125+
bits_to_shift = wrapper.constant(1, shape, dtype_name)
126+
127+
result = wrapper.bitshiftr(values, bits_to_shift)
128+
129+
assert dtype.c_api_value_to_dtype(wrapper.get_type(result)) == dtype_name, f"Failed for dtype: {dtype_name}"
130+
131+
132+
@pytest.mark.parametrize(
133+
"invdtypes",
134+
[
135+
dtype.c32,
136+
dtype.f64,
137+
],
138+
)
139+
def test_bitshiftr_supported_dtypes(invdtypes: dtype.Dtype) -> None:
140+
"""Test bitshift operations for unsupported integer data types."""
141+
shape = (5, 5)
142+
with pytest.raises(RuntimeError):
143+
value = wrapper.randu(shape, invdtypes)
144+
shift_amount = wrapper.constant(1, shape, invdtypes)
145+
146+
result = wrapper.bitshiftr(value, shift_amount)
147+
assert dtype.c_api_value_to_dtype(wrapper.get_type(result)) == invdtypes, f"Failed for dtype: {invdtypes}"
148+
149+
150+
@pytest.mark.parametrize("input_size", [8, 10, 12])
151+
def test_bitshift_right_varying_input_size(input_size: int) -> None:
152+
"""Test bitshift right operation with varying input sizes"""
153+
shape = (input_size, input_size)
154+
value = wrapper.randu(shape, dtype.int16)
155+
shift_amount = wrapper.constant(1, shape, dtype.int16) # Fixed shift amount for simplicity
156+
157+
result = wrapper.bitshiftr(value, shift_amount)
158+
159+
assert (wrapper.get_dims(result)[0], wrapper.get_dims(result)[1]) == shape
160+
161+
162+
@pytest.mark.parametrize(
163+
"shape",
164+
[
165+
(10,),
166+
(5, 5),
167+
(2, 3, 4),
168+
],
169+
)
170+
def test_bitshiftr_varying_shapes(shape: tuple) -> None:
171+
"""Test right bit shifting with arrays of varying shapes."""
172+
values = wrapper.randu(shape, dtype.int16)
173+
bits_to_shift = wrapper.constant(1, shape, dtype.int16)
174+
175+
result = wrapper.bitshiftr(values, bits_to_shift)
176+
177+
assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa
178+
179+
180+
@pytest.mark.parametrize(
181+
"shape_a, shape_b",
182+
[
183+
((1, 5), (5, 1)), # 2D with 2D inverse
184+
((5, 5), (5, 1)), # 2D with 2D
185+
((5, 5), (1, 1)), # 2D with 2D
186+
((1, 1, 1), (5, 5, 5)), # 3D with 3D
187+
],
188+
)
189+
def test_bitshiftr_different_shapes(shape_a: tuple, shape_b: tuple) -> None:
190+
"""Test if right bit shifting handles arrays of different shapes"""
191+
with pytest.raises(RuntimeError):
192+
values = wrapper.randu(shape_a, dtype.int16)
193+
bits_to_shift = wrapper.constant(1, shape_b, dtype.int16)
194+
result = wrapper.bitshiftr(values, bits_to_shift)
195+
print(array_to_string("", result, 3, False))
196+
assert (
197+
wrapper.get_dims(result)[0 : len(shape_a)] == shape_a # noqa
198+
), f"Failed for shapes {shape_a} and {shape_b}"

0 commit comments

Comments
 (0)