Skip to content

Commit

Permalink
python312Packages.tinygrad: 0.9.0 -> 0.9.2
Browse files Browse the repository at this point in the history
  • Loading branch information
GaetanLepage committed Sep 18, 2024
1 parent e5074df commit 7643022
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 54 deletions.
66 changes: 34 additions & 32 deletions pkgs/development/python-modules/tinygrad/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -11,38 +11,43 @@
ocl-icd,
stdenv,
rocmPackages,

# build-system
setuptools,
wheel,

# dependencies
numpy,
tqdm,
# nativeCheckInputs

# tests
blobfile,
bottle,
clang,
hexdump,
hypothesis,
librosa,
onnx,
pillow,
pydot,
pytest-xdist,
pytestCheckHook,
safetensors,
sentencepiece,
tiktoken,
torch,
tqdm,
transformers,
}:

buildPythonPackage rec {
pname = "tinygrad";
version = "0.9.0";
version = "0.9.2";
pyproject = true;

src = fetchFromGitHub {
owner = "tinygrad";
repo = "tinygrad";
rev = "refs/tags/v${version}";
hash = "sha256-opBxciETZruZjHqz/3vO7rogzjvVJKItulIiok/Zs2Y=";
hash = "sha256-fCKtJhZtqq6yjc6m41uvikzM9GArUlB8Q7jN/Np8+SM=";
};

patches = [
Expand All @@ -62,29 +67,20 @@ buildPythonPackage rec {
substituteInPlace tinygrad/runtime/autogen/opencl.py \
--replace-fail "ctypes.util.find_library('OpenCL')" "'${ocl-icd}/lib/libOpenCL.so'"
''
# hipGetDevicePropertiesR0600 is a symbol from rocm-6. We are currently at rocm-5.
# We are not sure that this works. Remove when rocm gets updated to version 6.
+ lib.optionalString rocmSupport ''
substituteInPlace extra/hip_gpu_driver/hip_ioctl.py \
--replace-fail "processor = platform.processor()" "processor = ${stdenv.hostPlatform.linuxArch}"
substituteInPlace tinygrad/runtime/autogen/hip.py \
--replace-fail "/opt/rocm/lib/libamdhip64.so" "${rocmPackages.clr}/lib/libamdhip64.so" \
--replace-fail "/opt/rocm/lib/libhiprtc.so" "${rocmPackages.clr}/lib/libhiprtc.so" \
--replace-fail "hipGetDevicePropertiesR0600" "hipGetDeviceProperties"
substituteInPlace tinygrad/runtime/autogen/comgr.py \
--replace-fail "/opt/rocm/lib/libamd_comgr.so" "${rocmPackages.rocm-comgr}/lib/libamd_comgr.so"
'';

build-system = [
setuptools
wheel
];
build-system = [ setuptools ];

dependencies =
[
numpy
tqdm
]
++ lib.optionals stdenv.isDarwin [
# pyobjc-framework-libdispatch
Expand All @@ -94,18 +90,22 @@ buildPythonPackage rec {
pythonImportsCheck = [ "tinygrad" ];

nativeCheckInputs = [
blobfile
bottle
clang
hexdump
hypothesis
librosa
onnx
pillow
pydot
pytest-xdist
pytestCheckHook
safetensors
sentencepiece
tiktoken
torch
tqdm
transformers
];

Expand All @@ -115,6 +115,10 @@ buildPythonPackage rec {

disabledTests =
[
# flaky: https://github.com/tinygrad/tinygrad/issues/6542
# TODO: re-enable when https://github.com/tinygrad/tinygrad/pull/6560 gets merged
"test_broadcastdot"

# Require internet access
"test_benchmark_openpilot_model"
"test_bn_alone"
Expand All @@ -129,12 +133,14 @@ buildPythonPackage rec {
"test_e2e_big"
"test_fetch_small"
"test_huggingface_enet_safetensors"
"test_index_mnist"
"test_linear_mnist"
"test_load_convnext"
"test_load_enet"
"test_load_enet_alt"
"test_load_llama2bfloat"
"test_load_resnet"
"test_mnist_val"
"test_openpilot_model"
"test_resnet"
"test_shufflenet"
Expand All @@ -148,32 +154,28 @@ buildPythonPackage rec {
]
# Fail on aarch64-linux with AssertionError
++ lib.optionals (stdenv.hostPlatform.system == "aarch64-linux") [
"test_casts_to"
"test_casts_to"
"test_int8_to_uint16_negative"
"test_casts_to"
"test_casts_to"
"test_casts_from"
"test_casts_to"
"test_int8"
"test_casts_to"
"test_int8_to_uint16_negative"
];

disabledTestPaths =
[
# Require internet access
"test/models/test_mnist.py"
"test/models/test_real_world.py"
"test/testextra/test_lr_scheduler.py"
]
++ lib.optionals (!rocmSupport) [ "extra/hip_gpu_driver/" ];
disabledTestPaths = [
# Require internet access
"test/models/test_mnist.py"
"test/models/test_real_world.py"
"test/testextra/test_lr_scheduler.py"

# Files under this directory are not considered as tests by upstream and should be skipped
"extra/"
];

meta = with lib; {
meta = {
description = "Simple and powerful neural network framework";
homepage = "https://github.com/tinygrad/tinygrad";
changelog = "https://github.com/tinygrad/tinygrad/releases/tag/v${version}";
license = licenses.mit;
maintainers = with maintainers; [ GaetanLepage ];
license = lib.licenses.mit;
maintainers = with lib.maintainers; [ GaetanLepage ];
# Requires unpackaged pyobjc-framework-libdispatch and pyobjc-framework-metal
broken = stdenv.isDarwin;
};
Expand Down
61 changes: 39 additions & 22 deletions pkgs/development/python-modules/tinygrad/fix-dlopen-cuda.patch
Original file line number Diff line number Diff line change
@@ -1,32 +1,49 @@
diff --git a/tinygrad/runtime/autogen/cuda.py b/tinygrad/runtime/autogen/cuda.py
index 359083a9..3cd5f7be 100644
index a30c8f53..e2078ff6 100644
--- a/tinygrad/runtime/autogen/cuda.py
+++ b/tinygrad/runtime/autogen/cuda.py
@@ -143,10 +143,25 @@ def char_pointer_cast(string, encoding='utf-8'):
return ctypes.cast(string, ctypes.POINTER(ctypes.c_char))
@@ -145,7 +145,19 @@ def char_pointer_cast(string, encoding='utf-8'):


+NAME_TO_PATHS = {
+ "libcuda.so": ["@driverLink@/lib/libcuda.so"],
+ "libnvrtc.so": ["@libnvrtc@"],
+}
+def _try_dlopen(name):
+ try:
+ return ctypes.CDLL(name)
+ except OSError:
+ pass
+ for candidate in NAME_TO_PATHS.get(name, []):
+ try:
+ return ctypes.CDLL(candidate)
+ except OSError:
+ pass
+ raise RuntimeError(f"{name} not found")

_libraries = {}
-_libraries['libcuda.so'] = ctypes.CDLL(ctypes.util.find_library('cuda'))
-_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc'))
+_libraries['libcuda.so'] = _try_dlopen('libcuda.so')
+_libraries['libnvrtc.so'] = _try_dlopen('libnvrtc.so')
+libcuda = None
+try:
+ libcuda = ctypes.CDLL('libcuda.so')
+except OSError:
+ pass
+try:
+ libcuda = ctypes.CDLL('@driverLink@/lib/libcuda.so')
+except OSError:
+ pass
+if libcuda is None:
+ raise RuntimeError(f"`libcuda.so` not found")
+
+_libraries['libcuda.so'] = libcuda


cuuint32_t = ctypes.c_uint32
diff --git a/tinygrad/runtime/autogen/nvrtc.py b/tinygrad/runtime/autogen/nvrtc.py
index 6af74187..c5a6c6c4 100644
--- a/tinygrad/runtime/autogen/nvrtc.py
+++ b/tinygrad/runtime/autogen/nvrtc.py
@@ -10,7 +10,18 @@ import ctypes, ctypes.util


_libraries = {}
-_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc'))
+libnvrtc = None
+try:
+ libnvrtc = ctypes.CDLL('libnvrtc.so')
+except OSError:
+ pass
+try:
+ libnvrtc = ctypes.CDLL('@libnvrtc@')
+except OSError:
+ pass
+if libnvrtc is None:
+ raise RuntimeError(f"`libnvrtc.so` not found")
+_libraries['libnvrtc.so'] = ctypes.CDLL(libnvrtc)
def string_cast(char_pointer, encoding='utf-8', errors='strict'):
value = ctypes.cast(char_pointer, ctypes.c_char_p).value
if value is not None and encoding is not None:

0 comments on commit 7643022

Please sign in to comment.