1010 _real_numeric_dtypes ,
1111 _numeric_dtypes ,
1212 _result_type ,
13- _dtype_categories as _dtype_dtype_categories ,
13+ _dtype_categories ,
1414)
1515from ._array_object import Array
1616from ._flags import requires_api_version
@@ -51,6 +51,22 @@ def inner(x1: Array, x2: Array, /) -> Array:
5151 return inner
5252
5353
54+
55+ # static type annotation for ArrayOrPythonScalar arguments given a category
56+ # NB: keep the keys in sync with the _dtype_categories dict
57+ _annotations = {
58+ "all" : "bool | int | float | complex | Array" ,
59+ "real numeric" : "int | float | Array" ,
60+ "numeric" : "int | float | complex | Array" ,
61+ "integer" : "int | Array" ,
62+ "integer or boolean" : "int | bool | Array" ,
63+ "boolean" : "bool | Array" ,
64+ "real floating-point" : "float | Array" ,
65+ "complex floating-point" : "complex | Array" ,
66+ "floating-point" : "float | complex | Array" ,
67+ }
68+
69+
5470# func_name: dtype_category (must match that from _dtypes.py)
5571_binary_funcs = {
5672 "add" : "numeric" ,
@@ -97,7 +113,7 @@ def inner(x1: Array, x2: Array, /) -> Array:
97113# create and attach functions to the module
98114for func_name , dtype_category in _binary_funcs .items ():
99115 # sanity check
100- assert dtype_category in _dtype_dtype_categories
116+ assert dtype_category in _dtype_categories
101117
102118 numpy_name = _numpy_renames .get (func_name , func_name )
103119 np_func = getattr (np , numpy_name )
@@ -106,6 +122,8 @@ def inner(x1: Array, x2: Array, /) -> Array:
106122 func .__name__ = func_name
107123
108124 func .__doc__ = _binary_docstring_template % (numpy_name , numpy_name )
125+ func .__annotations__ ['x1' ] = _annotations [dtype_category ]
126+ func .__annotations__ ['x2' ] = _annotations [dtype_category ]
109127
110128 vars ()[func_name ] = func
111129
@@ -117,15 +135,15 @@ def inner(x1: Array, x2: Array, /) -> Array:
117135nextafter = requires_api_version ('2024.12' )(nextafter ) # noqa: F821
118136
119137
120- def bitwise_left_shift (x1 : Array , x2 : Array , / ) -> Array :
138+ def bitwise_left_shift (x1 : int | Array , x2 : int | Array , / ) -> Array :
121139 is_negative = np .any (x2 ._array < 0 ) if isinstance (x2 , Array ) else x2 < 0
122140 if is_negative :
123141 raise ValueError ("bitwise_left_shift(x1, x2) is only defined for x2 >= 0" )
124142 return _bitwise_left_shift (x1 , x2 ) # noqa: F821
125143bitwise_left_shift .__doc__ = _bitwise_left_shift .__doc__ # noqa: F821
126144
127145
128- def bitwise_right_shift (x1 : Array , x2 : Array , / ) -> Array :
146+ def bitwise_right_shift (x1 : int | Array , x2 : int | Array , / ) -> Array :
129147 is_negative = np .any (x2 ._array < 0 ) if isinstance (x2 , Array ) else x2 < 0
130148 if is_negative :
131149 raise ValueError ("bitwise_left_shift(x1, x2) is only defined for x2 >= 0" )
0 commit comments