@@ -73,6 +73,7 @@ def test_take_along_axis(x, data):
7373 # TODO
7474 # 2. negative indices
7575 # 3. different dtypes for indices
76+ # 4. "broadcast-compatible" indices
7677 axis = data .draw (
7778 st .integers (- x .ndim , max (x .ndim - 1 , 0 )) | st .none (),
7879 label = "axis"
@@ -84,8 +85,8 @@ def test_take_along_axis(x, data):
8485 axis_kw = {"axis" : axis }
8586 n_axis = axis + x .ndim if axis < 0 else axis
8687
87- len_axis = data .draw (st .integers (0 , 2 * x .shape [n_axis ]), label = "len_axis " )
88- idx_shape = x .shape [:n_axis ] + (len_axis ,) + x .shape [n_axis + 1 :]
88+ new_len = data .draw (st .integers (0 , 2 * x .shape [n_axis ]), label = "new_len " )
89+ idx_shape = x .shape [:n_axis ] + (new_len ,) + x .shape [n_axis + 1 :]
8990 indices = data .draw (
9091 hh .arrays (
9192 shape = idx_shape ,
@@ -102,7 +103,7 @@ def test_take_along_axis(x, data):
102103 ph .assert_shape (
103104 "take_along_axis" ,
104105 out_shape = out .shape ,
105- expected = x .shape [:n_axis ] + (len_axis ,) + x .shape [n_axis + 1 :],
106+ expected = x .shape [:n_axis ] + (new_len ,) + x .shape [n_axis + 1 :],
106107 kw = dict (
107108 x = x ,
108109 indices = indices ,
@@ -117,5 +118,5 @@ def test_take_along_axis(x, data):
117118 a_1d = x [ii + (slice (None ),) + kk ]
118119 i_1d = indices [ii + (slice (None ),) + kk ]
119120 o_1d = out [ii + (slice (None ),) + kk ]
120- for j in range (len_axis ):
121+ for j in range (new_len ):
121122 assert o_1d [j ] == a_1d [i_1d [j ]], f'{ ii = } , { kk = } , { j = } '
0 commit comments