Skip to content

Commit 2162737

Browse files
committedFeb 4, 2024
Add sigmoid, create_constant_array. Fix neg. Fix minor issues
1 parent 603da15 commit 2162737

File tree

7 files changed

+45
-13
lines changed

7 files changed

+45
-13
lines changed
 

Diff for: ‎arrayfire_wrapper/lib/__init__.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,14 @@
4141

4242
from .create_and_modify_array.assignment_and_indexing.lookup import lookup
4343

44-
__all__ += [
45-
"constant",
46-
"constant_complex",
47-
"constant_long",
48-
"constant_ulong",
49-
]
50-
from .create_and_modify_array.create_array.constant import constant, constant_complex, constant_long, constant_ulong
44+
__all__ += ["constant", "constant_complex", "constant_long", "constant_ulong", "create_constant_array"]
45+
from .create_and_modify_array.create_array.constant import (
46+
constant,
47+
constant_complex,
48+
constant_long,
49+
constant_ulong,
50+
create_constant_array,
51+
)
5152

5253
__all__ += ["diag_create", "diag_extract"]
5354

@@ -539,6 +540,7 @@
539540
"root",
540541
"rsqrt",
541542
"sqrt",
543+
"sigmoid",
542544
"tgamma",
543545
]
544546

@@ -558,6 +560,7 @@
558560
pow2,
559561
root,
560562
rsqrt,
563+
sigmoid,
561564
sqrt,
562565
tgamma,
563566
)

Diff for: ‎arrayfire_wrapper/lib/create_and_modify_array/create_array/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# flake8: noqa
2-
__all__ = ["constant", "constant_complex", "constant_long", "constant_ulong"]
2+
__all__ = ["constant", "constant_complex", "constant_long", "constant_ulong", "create_constant_array"]
33

4-
from .constant import constant, constant_complex, constant_long, constant_ulong
4+
from .constant import constant, constant_complex, constant_long, constant_ulong, create_constant_array
55

66
__all__ += [
77
"create_random_engine",

Diff for: ‎arrayfire_wrapper/lib/create_and_modify_array/create_array/constant.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import ctypes
22

33
from arrayfire_wrapper.defines import AFArray, CShape
4-
from arrayfire_wrapper.dtypes import Dtype
4+
from arrayfire_wrapper.dtypes import Dtype, complex32, complex64, implicit_dtype, int64, uint64
55
from arrayfire_wrapper.lib._utility import call_from_clib
66

77

@@ -76,3 +76,19 @@ def constant_ulong(number: int | float, shape: tuple[int, ...], dtype: Dtype, /)
7676
ctypes.pointer(c_shape.c_array),
7777
)
7878
return out
79+
80+
81+
def create_constant_array(number: int | float | complex, shape: tuple[int, ...], dtype: Dtype, /) -> AFArray:
82+
if not dtype:
83+
dtype = implicit_dtype(number, dtype)
84+
85+
if isinstance(number, complex):
86+
return constant_complex(number, shape, dtype if dtype in {complex32, complex64} else complex32)
87+
88+
if dtype == int64:
89+
return constant_long(number, shape, dtype)
90+
91+
if dtype == uint64:
92+
return constant_ulong(number, shape, dtype)
93+
94+
return constant(number, shape, dtype)

Diff for: ‎arrayfire_wrapper/lib/mathematical_functions/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
pow2,
2626
root,
2727
rsqrt,
28+
sigmoid,
2829
sqrt,
2930
tgamma,
3031
)
@@ -46,6 +47,7 @@
4647
"root",
4748
"rsqrt",
4849
"sqrt",
50+
"sigmoid",
4951
"tgamma",
5052
]
5153

Diff for: ‎arrayfire_wrapper/lib/mathematical_functions/exp_and_log_functions.py

+7
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,13 @@ def sqrt(arr: AFArray, /) -> AFArray:
114114
return unary_op(sqrt.__name__, arr)
115115

116116

117+
def sigmoid(arr: AFArray, /) -> AFArray:
118+
"""
119+
source: https://arrayfire.org/docs/group__arith__func__sigmoid.htm#gadf4280e3283b65264de75194e0a6d565
120+
"""
121+
return unary_op(sigmoid.__name__, arr)
122+
123+
117124
def tgamma(arr: AFArray, /) -> AFArray:
118125
"""
119126
source:

Diff for: ‎arrayfire_wrapper/lib/mathematical_functions/numeric_functions.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import ctypes
22

33
from arrayfire_wrapper.defines import AFArray
4+
from arrayfire_wrapper.dtypes import float32
45
from arrayfire_wrapper.lib._utility import binary_op, call_from_clib, unary_op
6+
from arrayfire_wrapper.lib.create_and_modify_array.create_array import create_constant_array
57
from arrayfire_wrapper.lib.mathematical_functions.arithmetic_operations import sub
68

79

@@ -46,7 +48,9 @@ def hypot(lhs: AFArray, rhs: AFArray, /) -> AFArray:
4648
"""
4749
source:
4850
"""
49-
return binary_op(hypot.__name__, lhs, rhs)
51+
out = AFArray.create_null_pointer()
52+
call_from_clib(hypot.__name__, lhs, rhs)
53+
return out
5054

5155

5256
def maxof(lhs: AFArray, rhs: AFArray, /) -> AFArray:
@@ -71,7 +75,7 @@ def mod(lhs: AFArray, rhs: AFArray, /) -> AFArray:
7175

7276

7377
def neg(arr: AFArray) -> AFArray:
74-
return sub(AFArray(0), arr)
78+
return sub(create_constant_array(0, (1,), float32), arr)
7579

7680

7781
def rem(lhs: AFArray, rhs: AFArray, /) -> AFArray:

Diff for: ‎arrayfire_wrapper/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22

33
_MAJOR = "0"
4-
_MINOR = "4"
4+
_MINOR = "5"
55
# On main and in a nightly release the patch should be one ahead of the last
66
# released build.
77
_PATCH = "0"

0 commit comments

Comments
 (0)
Please sign in to comment.