9
9
from enum import Enum
10
10
from pathlib import Path
11
11
from typing import Iterator
12
+ import sysconfig
12
13
13
14
from .defines import is_arch_x86
14
15
from .version import ARRAYFIRE_VER_MAJOR
@@ -36,6 +37,7 @@ class _BackendPathConfig:
36
37
lib_prefix : str
37
38
lib_postfix : str
38
39
af_path : Path
40
+ af_is_user_path : bool
39
41
cuda_found : bool
40
42
41
43
def __iter__ (self ) -> Iterator :
@@ -46,16 +48,20 @@ def _get_backend_path_config() -> _BackendPathConfig:
46
48
platform_name = platform .system ()
47
49
cuda_found = False
48
50
51
+ # try to use user provided AF_PATH if explicitly set
49
52
try :
50
53
af_path = Path (os .environ ["AF_PATH" ])
54
+ af_is_user_path = True
51
55
except KeyError :
52
56
af_path = None
57
+ af_is_user_path = False
53
58
54
59
try :
55
60
cuda_path = Path (os .environ ["CUDA_PATH" ])
56
61
except KeyError :
57
62
cuda_path = None
58
63
64
+ # try to find default arrayfire installation paths
59
65
if platform_name == _SupportedPlatforms .windows .value or _SupportedPlatforms .is_cygwin (platform_name ):
60
66
if platform_name == _SupportedPlatforms .windows .value :
61
67
# HACK Supressing crashes caused by missing dlls
@@ -64,40 +70,84 @@ def _get_backend_path_config() -> _BackendPathConfig:
64
70
ctypes .windll .kernel32 .SetErrorMode (0x0001 | 0x0002 ) # type: ignore[attr-defined]
65
71
66
72
if not af_path :
67
- af_path = _find_default_path (f"C:/Program Files/ArrayFire/v{ ARRAYFIRE_VER_MAJOR } " )
73
+ try :
74
+ af_path = _find_default_path (f"C:/Program Files/ArrayFire/v{ ARRAYFIRE_VER_MAJOR } " )
75
+ except ValueError :
76
+ af_path = None
77
+
68
78
69
79
if cuda_path and (cuda_path / "bin" ).is_dir () and (cuda_path / "nvvm/bin" ).is_dir ():
70
80
cuda_found = True
71
81
72
- return _BackendPathConfig ("" , ".dll" , af_path , cuda_found )
82
+ return _BackendPathConfig ("" , ".dll" , af_path , af_is_user_path , cuda_found )
73
83
74
84
if platform_name == _SupportedPlatforms .darwin .value :
75
85
default_cuda_path = Path ("/usr/local/cuda/" )
76
86
77
87
if not af_path :
78
88
af_path = _find_default_path ("/opt/arrayfire" , "/usr/local" )
89
+ try :
90
+ af_path = _find_default_path (f"C:/Program Files/ArrayFire/v{ ARRAYFIRE_VER_MAJOR } " ,
91
+ "C:/Program Files (x86)/ArrayFire/v{ARRAYFIRE_VER_MAJOR}" )
92
+ except ValueError :
93
+ af_path = None
79
94
80
95
if not (cuda_path and default_cuda_path .exists ()):
81
96
cuda_found = (default_cuda_path / "lib" ).is_dir () and (default_cuda_path / "/nvvm/lib" ).is_dir ()
82
97
83
- return _BackendPathConfig ("lib" , f".{ ARRAYFIRE_VER_MAJOR } .dylib" , af_path , cuda_found )
98
+ return _BackendPathConfig ("lib" , f".{ ARRAYFIRE_VER_MAJOR } .dylib" , af_path , af_is_user_path , cuda_found )
84
99
85
100
if platform_name == _SupportedPlatforms .linux .value :
86
101
default_cuda_path = Path ("/usr/local/cuda/" )
87
102
88
103
if not af_path :
89
- af_path = _find_default_path (f"/opt/arrayfire-{ ARRAYFIRE_VER_MAJOR } " , "/opt/arrayfire/" , "/usr/local/" )
104
+ try :
105
+ af_path = _find_default_path (f"/opt/arrayfire-{ ARRAYFIRE_VER_MAJOR } " , "/opt/arrayfire/" , "/usr/local/" )
106
+ except ValueError :
107
+ af_path = None
90
108
91
109
if not (cuda_path and default_cuda_path .exists ()):
92
110
if "64" in platform .architecture ()[0 ]: # Check either is 64 bit arch is selected
93
111
cuda_found = (default_cuda_path / "lib64" ).is_dir () and (default_cuda_path / "nvvm/lib64" ).is_dir ()
94
112
else :
95
113
cuda_found = (default_cuda_path / "lib" ).is_dir () and (default_cuda_path / "nvvm/lib" ).is_dir ()
96
114
97
- return _BackendPathConfig ("lib" , f".so.{ ARRAYFIRE_VER_MAJOR } " , af_path , cuda_found )
115
+ return _BackendPathConfig ("lib" , f".so.{ ARRAYFIRE_VER_MAJOR } " , af_path , af_is_user_path , cuda_found )
98
116
99
117
raise OSError (f"{ platform_name } is not supported." )
100
118
119
+ # finds paths to locally packaged arrayfire libraries if they exist in site
120
+ def _find_site_local_path () -> Path :
121
+ local_paths = ["." ]
122
+
123
+ # module search paths
124
+ af_module = __import__ (__name__ )
125
+ module_paths = af_module .__path__ if af_module .__path__ else []
126
+ for path in module_paths :
127
+ local_paths .append (path )
128
+
129
+ # site search path
130
+ purelib_path = sysconfig .get_path ('purelib' )
131
+ platlib_path = sysconfig .get_path ('platlib' )
132
+ local_paths .append (purelib_path )
133
+ local_paths .append (platlib_path )
134
+
135
+ # sys search path
136
+ local_paths .extend (sys .path )
137
+
138
+ module_name = af_module .__name__
139
+ for path in local_paths :
140
+ lpath = Path (path )
141
+ if lpath .exists ():
142
+ p = lpath .glob (f"{ module_name } /binaries/*" )
143
+ files = [x .name for x in p if x .is_file ()]
144
+ query_libnames = ['afcpu' , 'afoneapi' , 'afopencl' , 'afcuda' , 'af' , 'forge' ]
145
+ found_lib_in_dir = any (q in f for q in query_libnames for f in files )
146
+ if found_lib_in_dir :
147
+ print ( lpath )
148
+ print ( lpath / module_name / "binaries" )
149
+ return lpath / module_name / "binaries"
150
+ raise ValueError ("No binaries detected in site path." )
101
151
102
152
def _find_default_path (* args : str ) -> Path :
103
153
for path in args :
@@ -108,26 +158,51 @@ def _find_default_path(*args: str) -> Path:
108
158
109
159
110
160
class BackendType (enum .Enum ): # TODO change name - avoid using _backend_type - e.g. type
111
- unified = 0 # NOTE It is set as Default value on Arrayfire backend
112
- cpu = 1
113
161
cuda = 2
114
162
opencl = 4
115
163
oneapi = 8
164
+ cpu = 1
165
+ unified = 0 # NOTE It is set as Default value on Arrayfire backend
116
166
117
167
def __iter__ (self ) -> Iterator :
118
168
# NOTE cpu comes last because we want to keep this order priorty during backend initialization
119
- return iter ((self .unified , self .cuda , self .oneapi , self .opencl , self .cpu ))
169
+ return iter ((self .unified , self .cuda , self .opencl , self .oneapi , self .cpu ))
120
170
121
171
122
172
class Backend :
123
173
_backend_type : BackendType
124
- _clib : ctypes .CDLL
174
+ _clibs : dict [ BackendType , ctypes .CDLL ]
125
175
126
176
def __init__ (self ) -> None :
127
177
self ._backend_path_config = _get_backend_path_config ()
128
178
129
- self ._load_forge_lib ()
179
+ self ._backend_type = None
180
+ self ._clibs = {}
130
181
self ._load_backend_libs ()
182
+ self ._load_forge_lib ()
183
+
184
+ def set_backend (self , backend_type : BackendType ) -> None :
185
+ # if unified is available, do dynamic module loading through libaf
186
+ if self ._backend_type == BackendType .unified :
187
+ import pdb ;pdb .set_trace ()
188
+ from arrayfire_wrapper .lib .unified_api_functions import set_backend as unified_set_backend
189
+ try :
190
+ unified_set_backend (backend_type )
191
+ except RuntimeError :
192
+ if VERBOSE_LOADS :
193
+ print (f"Unable to change backend using unified loader" )
194
+ raise RuntimeError
195
+ # if unified not available
196
+ else :
197
+ if backend_type in self ._clibs :
198
+ self ._backend_type = backend_type
199
+ else :
200
+ self ._backend_path_config = _get_backend_path_config ()
201
+
202
+ self ._backend_type = None
203
+ #self._clib = None
204
+ self ._load_backend_libs (backend_type )
205
+ #self._load_forge_lib() needed to reload?
131
206
132
207
def _load_forge_lib (self ) -> None :
133
208
for lib_name in self ._lib_names ("forge" , _LibPrefixes .forge ):
@@ -141,16 +216,18 @@ def _load_forge_lib(self) -> None:
141
216
print (f"Unable to load { lib_name } " )
142
217
pass
143
218
144
- def _load_backend_libs (self ) -> None :
145
- for backend_type in BackendType :
219
+ def _load_backend_libs (self , specific_backend : BackendType | None = None ) -> None :
220
+ available_backends = [specific_backend ] if specific_backend else list (BackendType )
221
+ for backend_type in available_backends :
222
+ print (backend_type )
146
223
self ._load_backend_lib (backend_type )
147
224
148
225
if self ._backend_type :
149
226
if VERBOSE_LOADS :
150
227
print (f"Setting { backend_type .name } as backend." )
151
228
break
152
229
153
- if not self ._backend_type and not self ._clib :
230
+ if not self ._backend_type and not self ._clibs :
154
231
raise RuntimeError (
155
232
"Could not load any ArrayFire libraries.\n "
156
233
"Please look at https://github.com/arrayfire/arrayfire-python/wiki for more information."
@@ -164,7 +241,7 @@ def _load_backend_lib(self, _backend_type: BackendType) -> None:
164
241
try :
165
242
ctypes .cdll .LoadLibrary (str (lib_name ))
166
243
self ._backend_type = _backend_type
167
- self ._clib = ctypes .CDLL (str (lib_name ))
244
+ self ._clibs [ _backend_type ] = ctypes .CDLL (str (lib_name ))
168
245
169
246
if _backend_type == BackendType .cuda :
170
247
self ._load_nvrtc_builtins_lib (lib_name .parent )
@@ -191,22 +268,22 @@ def _lib_names(self, name: str, lib: _LibPrefixes, ver_major: str | None = None)
191
268
post = self ._backend_path_config .lib_postfix if ver_major is None else ver_major
192
269
lib_name = self ._backend_path_config .lib_prefix + lib .value + name + post
193
270
194
- lib64_path = self ._backend_path_config .af_path / "lib64"
195
- search_path = lib64_path if lib64_path .is_dir () else self ._backend_path_config .af_path / "lib"
196
-
197
- site_path = Path (sys .prefix ) / "lib64" if not is_arch_x86 () else Path (sys .prefix ) / "lib"
198
-
199
- # prefer locally packaged arrayfire libraries if they exist
200
- af_module = __import__ (__name__ )
201
- local_path = Path (af_module .__path__ [0 ]) if af_module .__path__ else Path ("" )
271
+ lib_paths = [Path ("" , lib_name )]
202
272
203
- lib_paths = [Path ("" , lib_name ), site_path / lib_name , local_path / lib_name ]
273
+ # use local or site packaged arrayfire libraries if they exist
274
+ local_path = _find_site_local_path ()
275
+ lib_paths .append (local_path / lib_name )
204
276
205
277
if self ._backend_path_config .af_path : # prefer specified AF_PATH if exists
206
- return [search_path / lib_name ] + lib_paths
207
- else :
208
- lib_paths .insert (2 , Path (str (search_path ), lib_name ))
209
- return lib_paths
278
+ lib64_path = self ._backend_path_config .af_path / "lib64"
279
+ search_path = lib64_path if lib64_path .is_dir () else self ._backend_path_config .af_path / "lib"
280
+ # prefer path explicitly set by user through AF_PATH
281
+ if self ._backend_path_config .af_is_user_path :
282
+ return [search_path / lib_name ] + lib_paths
283
+ # otherwise, prefer to use site-packaged or local path
284
+ return lib_paths + [search_path / lib_name ]
285
+
286
+ return lib_paths
210
287
211
288
def _find_nvrtc_builtins_lib_name (self , search_path : Path ) -> str | None :
212
289
for f in search_path .iterdir ():
@@ -220,7 +297,7 @@ def backend_type(self) -> BackendType:
220
297
221
298
@property
222
299
def clib (self ) -> ctypes .CDLL :
223
- return self ._clib
300
+ return self ._clibs [ self . _backend_type ]
224
301
225
302
226
303
# Initialize the backend
@@ -238,3 +315,12 @@ def get_backend() -> Backend:
238
315
"""
239
316
240
317
return __backend
318
+
319
+ def set_backend (backend_type : BackendType ) -> None :
320
+
321
+ try :
322
+ backend = get_backend ()
323
+ backend .set_backend (backend_type )
324
+ except RuntimeError :
325
+ print (f"Requested backend { backend_type .name } could not be found" )
326
+
0 commit comments