Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit f4f06fe

Browse files
committedJan 31, 2024
Fix missing imports. Fix return issue in svd_inplace
1 parent 2252f36 commit f4f06fe

File tree

4 files changed

+9
-7
lines changed

4 files changed

+9
-7
lines changed
 

Diff for: ‎arrayfire_wrapper/lib/__init__.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@
439439
"is_lapack_available",
440440
"cholesky",
441441
"cholesky_inplace",
442+
"gemm",
442443
"lu",
443444
"lu_inplace",
444445
"qr",
@@ -478,6 +479,7 @@
478479
det,
479480
dot,
480481
dot_all,
482+
gemm,
481483
inverse,
482484
is_lapack_available,
483485
lu,
@@ -923,6 +925,6 @@
923925

924926
# Constants
925927

926-
__all__ += ["Match", "Moment", "Pad", "PointerSource", "TopK", "VarianceBias"]
928+
__all__ += ["Match", "Moment", "Pad", "PointerSource", "TopK", "VarianceBias", "MatProp", "ImageFormat"]
927929

928-
from ._constants import Match, Moment, Pad, PointerSource, TopK, VarianceBias
930+
from ._constants import ImageFormat, Match, MatProp, Moment, Pad, PointerSource, TopK, VarianceBias

Diff for: ‎arrayfire_wrapper/lib/linear_algebra/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# flake8: noqa
22

3-
__all__ = ["dot", "dot_all", "matmul"]
4-
from .blas_operations import dot, dot_all, matmul
3+
__all__ = ["dot", "dot_all", "matmul", "gemm"]
4+
from .blas_operations import dot, dot_all, gemm, matmul
55

66
__all__ += ["is_lapack_available"]
77
from .lapack_helpers import is_lapack_available

Diff for: ‎arrayfire_wrapper/lib/linear_algebra/matrix_factorization_and_decomposition.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,12 @@ def svd(arr: AFArray, /) -> tuple[AFArray, AFArray, AFArray]:
7474
return (u, s, vt)
7575

7676

77-
def svd_inplace(arr: AFArray, /) -> tuple[AFArray, AFArray, AFArray]:
77+
def svd_inplace(arr: AFArray, /) -> tuple[AFArray, AFArray, AFArray, AFArray]:
7878
"""
7979
source: https://arrayfire.org/docs/group__lapack__factor__func__svd.htm#ga80b31f7671bf00143dd992df8d585a2d
8080
"""
8181
u = AFArray.create_null_pointer()
8282
s = AFArray.create_null_pointer()
8383
vt = AFArray.create_null_pointer()
8484
call_from_clib(svd.__name__, ctypes.pointer(u), ctypes.pointer(s), ctypes.pointer(vt), arr)
85-
return (u, s, vt)
85+
return (u, s, vt, arr)

Diff for: ‎arrayfire_wrapper/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22

33
_MAJOR = "0"
4-
_MINOR = "2"
4+
_MINOR = "3"
55
# On main and in a nightly release the patch should be one ahead of the last
66
# released build.
77
_PATCH = "0"

0 commit comments

Comments
 (0)
Please sign in to comment.