forked from NixOS/nixpkgs
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
python312Packages.tinygrad: 0.9.0 -> 0.9.2
Diff: tinygrad/tinygrad@refs/tags/v0.9.0...v0.9.2 Changelog: https://github.com/tinygrad/tinygrad/releases/tag/v0.9.2
- Loading branch information
1 parent
e5074df
commit 7643022
Showing
2 changed files
with
73 additions
and
54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
61 changes: 39 additions & 22 deletions
61
pkgs/development/python-modules/tinygrad/fix-dlopen-cuda.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,49 @@ | ||
diff --git a/tinygrad/runtime/autogen/cuda.py b/tinygrad/runtime/autogen/cuda.py | ||
index 359083a9..3cd5f7be 100644 | ||
index a30c8f53..e2078ff6 100644 | ||
--- a/tinygrad/runtime/autogen/cuda.py | ||
+++ b/tinygrad/runtime/autogen/cuda.py | ||
@@ -143,10 +143,25 @@ def char_pointer_cast(string, encoding='utf-8'): | ||
return ctypes.cast(string, ctypes.POINTER(ctypes.c_char)) | ||
@@ -145,7 +145,19 @@ def char_pointer_cast(string, encoding='utf-8'): | ||
|
||
|
||
+NAME_TO_PATHS = { | ||
+ "libcuda.so": ["@driverLink@/lib/libcuda.so"], | ||
+ "libnvrtc.so": ["@libnvrtc@"], | ||
+} | ||
+def _try_dlopen(name): | ||
+ try: | ||
+ return ctypes.CDLL(name) | ||
+ except OSError: | ||
+ pass | ||
+ for candidate in NAME_TO_PATHS.get(name, []): | ||
+ try: | ||
+ return ctypes.CDLL(candidate) | ||
+ except OSError: | ||
+ pass | ||
+ raise RuntimeError(f"{name} not found") | ||
|
||
_libraries = {} | ||
-_libraries['libcuda.so'] = ctypes.CDLL(ctypes.util.find_library('cuda')) | ||
-_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc')) | ||
+_libraries['libcuda.so'] = _try_dlopen('libcuda.so') | ||
+_libraries['libnvrtc.so'] = _try_dlopen('libnvrtc.so') | ||
+libcuda = None | ||
+try: | ||
+ libcuda = ctypes.CDLL('libcuda.so') | ||
+except OSError: | ||
+ pass | ||
+try: | ||
+ libcuda = ctypes.CDLL('@driverLink@/lib/libcuda.so') | ||
+except OSError: | ||
+ pass | ||
+if libcuda is None: | ||
+ raise RuntimeError(f"`libcuda.so` not found") | ||
+ | ||
+_libraries['libcuda.so'] = libcuda | ||
|
||
|
||
cuuint32_t = ctypes.c_uint32 | ||
diff --git a/tinygrad/runtime/autogen/nvrtc.py b/tinygrad/runtime/autogen/nvrtc.py | ||
index 6af74187..c5a6c6c4 100644 | ||
--- a/tinygrad/runtime/autogen/nvrtc.py | ||
+++ b/tinygrad/runtime/autogen/nvrtc.py | ||
@@ -10,7 +10,18 @@ import ctypes, ctypes.util | ||
|
||
|
||
_libraries = {} | ||
-_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc')) | ||
+libnvrtc = None | ||
+try: | ||
+ libnvrtc = ctypes.CDLL('libnvrtc.so') | ||
+except OSError: | ||
+ pass | ||
+try: | ||
+ libnvrtc = ctypes.CDLL('@libnvrtc@') | ||
+except OSError: | ||
+ pass | ||
+if libnvrtc is None: | ||
+ raise RuntimeError(f"`libnvrtc.so` not found") | ||
+_libraries['libnvrtc.so'] = ctypes.CDLL(libnvrtc) | ||
def string_cast(char_pointer, encoding='utf-8', errors='strict'): | ||
value = ctypes.cast(char_pointer, ctypes.c_char_p).value | ||
if value is not None and encoding is not None: |