Skip to content

Commit 8594c66

Browse files
committed
export set_backend
1 parent dd255c4 commit 8594c66

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

arrayfire_wrapper/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
"Backend",
1818
"BackendType",
1919
"get_backend",
20+
"set_backend",
2021
]
21-
from ._backend import Backend, BackendType, get_backend
22+
from ._backend import Backend, BackendType, get_backend, set_backend
2223

2324
__all__ += [
2425
"Dtype",

arrayfire_wrapper/_backend.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,9 @@ def _find_site_local_path() -> Path:
144144
query_libnames = ['afcpu', 'afoneapi', 'afopencl', 'afcuda', 'af', 'forge']
145145
found_lib_in_dir = any(q in f for q in query_libnames for f in files)
146146
if found_lib_in_dir:
147-
print( lpath)
148-
print( lpath / module_name / "binaries")
147+
if VERBOSE_LOADS:
148+
print( lpath)
149+
print( lpath / module_name / "binaries")
149150
return lpath / module_name / "binaries"
150151
raise ValueError("No binaries detected in site path.")
151152

@@ -166,7 +167,7 @@ class BackendType(enum.Enum): # TODO change name - avoid using _backend_type -
166167

167168
def __iter__(self) -> Iterator:
168169
# NOTE cpu comes last because we want to keep this order priorty during backend initialization
169-
return iter((self.unified, self.cuda, self.opencl, self.oneapi, self.cpu))
170+
return iter((self.cuda, self.opencl, self.oneapi, self.cpu, self.unified))
170171

171172

172173
class Backend:
@@ -184,7 +185,6 @@ def __init__(self) -> None:
184185
def set_backend(self, backend_type : BackendType) -> None:
185186
# if unified is available, do dynamic module loading through libaf
186187
if self._backend_type == BackendType.unified:
187-
import pdb;pdb.set_trace()
188188
from arrayfire_wrapper.lib.unified_api_functions import set_backend as unified_set_backend
189189
try:
190190
unified_set_backend(backend_type)
@@ -198,11 +198,8 @@ def set_backend(self, backend_type : BackendType) -> None:
198198
self._backend_type = backend_type
199199
else:
200200
self._backend_path_config = _get_backend_path_config()
201-
202-
self._backend_type = None
203-
#self._clib = None
204201
self._load_backend_libs(backend_type)
205-
#self._load_forge_lib() needed to reload?
202+
#self._load_forge_lib() # needed to reload?
206203

207204
def _load_forge_lib(self) -> None:
208205
for lib_name in self._lib_names("forge", _LibPrefixes.forge):
@@ -219,7 +216,6 @@ def _load_forge_lib(self) -> None:
219216
def _load_backend_libs(self, specific_backend : BackendType | None = None) -> None:
220217
available_backends = [specific_backend] if specific_backend else list(BackendType)
221218
for backend_type in available_backends:
222-
print(backend_type)
223219
self._load_backend_lib(backend_type)
224220

225221
if self._backend_type:
@@ -239,6 +235,8 @@ def _load_backend_lib(self, _backend_type: BackendType) -> None:
239235

240236
for lib_name in self._lib_names(name, _LibPrefixes.arrayfire):
241237
try:
238+
if VERBOSE_LOADS:
239+
print(f"Attempting to load {lib_name}")
242240
ctypes.cdll.LoadLibrary(str(lib_name))
243241
self._backend_type = _backend_type
244242
self._clibs[_backend_type] = ctypes.CDLL(str(lib_name))
@@ -271,8 +269,12 @@ def _lib_names(self, name: str, lib: _LibPrefixes, ver_major: str | None = None)
271269
lib_paths = [Path("", lib_name)]
272270

273271
# use local or site packaged arrayfire libraries if they exist
274-
local_path = _find_site_local_path()
275-
lib_paths.append(local_path / lib_name)
272+
try:
273+
local_path = _find_site_local_path()
274+
lib_paths.append(local_path / lib_name)
275+
except ValueError as e:
276+
if VERBOSE_LOADS:
277+
print(str(e))
276278

277279
if self._backend_path_config.af_path: # prefer specified AF_PATH if exists
278280
lib64_path = self._backend_path_config.af_path / "lib64"

0 commit comments

Comments
 (0)