@@ -68,7 +68,7 @@ def select(
     indices_tensor = ctx.net.add_constant(
         index_value.shape, to_numpy(index_value)
     ).get_output(0)
-    out = gather(input, indices_tensor, dim)
+    out = gather(ctx, target, source_ir, name, input, indices_tensor, dim)
     if len(out.shape) != 1:
         layer = ctx.net.add_shuffle(out)
     return layer.get_output(0)
@@ -140,7 +140,7 @@ def index(
         )
         index = adv_indx_indices[0]
         _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
-        return gather(input, index, indices_tensor)
+        return gather(ctx, target, source_ir, name, input, index, indices_tensor)
     else:
         input_shape = input.shape
         _LOGGER.debug(f"The input shape is {input.shape}")
@@ -253,7 +253,7 @@ def index(
                 dim_tensor_list[adv_indx_indices[i]],
             )
 
-        gather_out = gather(flatten_tensor, cum_adv_index, 0)
+        gather_out = gather(ctx, target, source_ir, name, flatten_tensor, 0, cum_adv_index)
         _LOGGER.debug(f"The shape after cumultative gather is {gather_out.shape}")
         _LOGGER.debug(f"The shape for cumulative adv index is {cum_adv_index}")
 
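Note on the new calling convention: each updated call site now passes the conversion context and op metadata (ctx, target, source_ir, name) ahead of the tensor arguments. The sketch below is a hypothetical wrapper consistent with the index() call sites, assuming an (input, dim, index) ordering and TensorRT's INetworkDefinition.add_gather; the repository's actual gather helper may differ.

def gather(ctx, target, source_ir, name, input, dim, index):
    # Illustrative sketch only: wrap TensorRT's IGatherLayer and derive the layer
    # name from the op metadata (the naming scheme here is an assumption).
    gather_layer = ctx.net.add_gather(input, index, dim)
    gather_layer.name = f"[{source_ir}] {target}: {name}_gather"
    return gather_layer.get_output(0)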