Skip to content

Commit 2b095bb

Browse files
committed
move unified functions to Backend class
1 parent 9d23ed2 commit 2b095bb

File tree

4 files changed

+47
-92
lines changed

4 files changed

+47
-92
lines changed

README.md

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# arrayfire-binary-python-wrapper
22

3-
[ArrayFire](https://github.com/arrayfire/arrayfire) is a high performance library for parallel computing with an easy-to-use API. It enables users to write scientific computing code that is portable across CUDA, OpenCL and CPU devices.
4-
This project is meant to provide thin Python bindings for the ArrayFire C library. It also decouples releases of the main C/C++ library from the Python library by acting as a intermediate library and only wrapping the provided C calls.
5-
This allows the building of large binary wheels only when the underlying ArrayFire version is increased, and the fully-featured Python library can be developed atop independently. This package is not intended to be used directly and merely exposes the
3+
[ArrayFire](https://github.com/arrayfire/arrayfire) is a high performance library for parallel computing with an easy-to-use API. It enables users to write scientific computing code that is portable across CUDA, OpenCL and CPU devices.
4+
5+
This project is meant to provide thin Python bindings for the ArrayFire C library. It also decouples releases of the main C/C++ library from the Python library by acting as a intermediate library and only wrapping the provided C calls.
6+
7+
This allows the building of large binary wheels only when the underlying ArrayFire version is increased, and the fully-featured Python library can be developed atop independently. The package is not intended to be used directly and merely exposes the
68
C functionality required by downstream implementations. This package can exist in two forms, with a bundled binary distribution, or merely as a loader that will load the ArrayFire library from a system or user level install.
79

810
# Installing

arrayfire_wrapper/_backend.py

+42-11
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
import platform
77
import sys
8+
from arrayfire_wrapper.defines import AFArray
89
from dataclasses import dataclass
910
from enum import Enum
1011
from pathlib import Path
@@ -13,7 +14,6 @@
1314

1415
from .defines import is_arch_x86
1516
from .version import ARRAYFIRE_VER_MAJOR
16-
from arrayfire_wrapper.lib.unified_api_functions import set_backend as unified_set_backend
1717

1818
VERBOSE_LOADS = os.environ.get("AF_VERBOSE_LOADS", "") == "1"
1919

@@ -149,7 +149,7 @@ def _find_site_local_path() -> Path:
149149
print( lpath)
150150
print( lpath / module_name / "binaries")
151151
return lpath / module_name / "binaries"
152-
raise ValueError("No binaries detected in site path.")
152+
raise RuntimeError("No binaries detected in site path.")
153153

154154
def _find_default_path(*args: str) -> Path:
155155
for path in args:
@@ -183,15 +183,14 @@ def __init__(self) -> None:
183183
self._load_backend_libs()
184184
self._load_forge_lib()
185185

186-
def set_backend(self, backend_type : BackendType) -> None:
186+
def _change_backend(self, backend_type : BackendType) -> None:
187187
# if unified is available, do dynamic module loading through libaf
188188
if self._backend_type == BackendType.unified:
189+
from arrayfire_wrapper.lib.unified_api_functions import set_backend as unified_set_backend
189190
try:
190191
unified_set_backend(backend_type)
191-
except RuntimeError:
192-
if VERBOSE_LOADS:
193-
print(f"Unable to change backend using unified loader")
194-
raise RuntimeError
192+
except RuntimeError as e:
193+
print(f"Unable to change backend using unified loader: {str(e)}")
195194
# if unified not available
196195
else:
197196
if backend_type in self._clibs:
@@ -273,9 +272,10 @@ def _lib_names(self, name: str, lib: _LibPrefixes, ver_major: str | None = None)
273272
try:
274273
local_path = _find_site_local_path()
275274
lib_paths.append(local_path / lib_name)
276-
except ValueError as e:
275+
except RuntimeError as e:
277276
if VERBOSE_LOADS:
278-
print(str(e))
277+
print(f"Moving on to system libraries, site local load failed due to: {str(e)}")
278+
pass
279279

280280
if self._backend_path_config.af_path: # prefer specified AF_PATH if exists
281281
lib64_path = self._backend_path_config.af_path / "lib64"
@@ -294,6 +294,38 @@ def _find_nvrtc_builtins_lib_name(self, search_path: Path) -> str | None:
294294
return f.name
295295
return None
296296

297+
298+
# unified backend functions
299+
def get_active_backend(self) -> str:
300+
if self._backend_type == BackendType.unified:
301+
from arrayfire_wrapper.lib.unified_api_functions import get_active_backend as unified_get_active_backend
302+
return unified_get_active_backend()
303+
raise RuntimeError("Using unified function on non-unified backend")
304+
305+
def get_available_backends(self) -> list[int]:
306+
if self._backend_type == BackendType.unified:
307+
from arrayfire_wrapper.lib.unified_api_functions import get_available_backends as unified_get_available_backends
308+
return unified_get_available_backends()
309+
raise RuntimeError("Using unified function on non-unified backend")
310+
311+
def get_backend_count(self) -> int:
312+
if self._backend_type == BackendType.unified:
313+
from arrayfire_wrapper.lib.unified_api_functions import get_backend_count as unified_get_backend_count
314+
return unified_get_backend_count()
315+
raise RuntimeError("Using unified function on non-unified backend")
316+
317+
def get_backend_id(self, arr: AFArray, /) -> int:
318+
if self._backend_type == BackendType.unified:
319+
from arrayfire_wrapper.lib.unified_api_functions import get_backend_id as unified_get_backend_id
320+
return unified_get_backend_id(arr)
321+
raise RuntimeError("Using unified function on non-unified backend")
322+
323+
def get_device_id(self, arr: AFArray, /) -> int:
324+
if self._backend_type == BackendType.unified:
325+
from arrayfire_wrapper.lib.unified_api_functions import get_device_id as unified_get_device_id
326+
return unified_get_device_id(arr)
327+
raise RuntimeError("Using unified function on non-unified backend")
328+
297329
@property
298330
def backend_type(self) -> BackendType:
299331
return self._backend_type
@@ -320,10 +352,9 @@ def get_backend() -> Backend:
320352
return __backend
321353

322354
def set_backend(backend_type : BackendType) -> None:
323-
324355
try:
325356
backend = get_backend()
326-
backend.set_backend(backend_type)
357+
backend._change_backend(backend_type)
327358
except RuntimeError:
328359
print(f"Requested backend {backend_type.name} could not be found")
329360

arrayfire_wrapper/lib/__init__.py

-20
Original file line numberDiff line numberDiff line change
@@ -910,26 +910,6 @@
910910
approx2_v2,
911911
)
912912

913-
# Unified API functions
914-
915-
__all__ += [
916-
"get_active_backend",
917-
"get_available_backends",
918-
"get_backend_count",
919-
"get_backend_id",
920-
"get_device_id",
921-
"set_backend",
922-
]
923-
924-
from .unified_api_functions import (
925-
get_active_backend,
926-
get_available_backends,
927-
get_backend_count,
928-
get_backend_id,
929-
get_device_id,
930-
set_backend,
931-
)
932-
933913
# Events
934914

935915
__all__ += ["AFEvent", "block_event", "create_event", "delete_event", "enqueue_wait_event", "mark_event"]

arrayfire_wrapper/lib/unified_api_functions.py

-58
This file was deleted.

0 commit comments

Comments
 (0)