Skip to content

Commit e593c9b

Browse files
committed
changing gather signature
1 parent c4ae041 commit e593c9b

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def aten_ops_native_group_norm(
183183
)
184184

185185

186-
@dynamo_tensorrt_converter(torch.ops.aten.gather)
186+
@dynamo_tensorrt_converter(torch.ops.aten.gather.default)
187187
@enforce_tensor_types(
188188
{
189189
0: (TRTTensor,),

py/torch_tensorrt/dynamo/conversion/impl/select.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def select(
6868
indices_tensor = ctx.net.add_constant(
6969
index_value.shape, to_numpy(index_value)
7070
).get_output(0)
71-
out = gather(input, indices_tensor, dim)
71+
out = gather(ctx, target, source_ir, name, input, indices_tensor, dim)
7272
if len(out.shape) != 1:
7373
layer = ctx.net.add_shuffle(out)
7474
return layer.get_output(0)
@@ -140,7 +140,7 @@ def index(
140140
)
141141
index = adv_indx_indices[0]
142142
_LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
143-
return gather(input, index, indices_tensor)
143+
return gather(ctx, target, source_ir, name, input, index, indices_tensor)
144144
else:
145145
input_shape = input.shape
146146
_LOGGER.debug(f"The input shape is {input.shape}")
@@ -253,7 +253,7 @@ def index(
253253
dim_tensor_list[adv_indx_indices[i]],
254254
)
255255

256-
gather_out = gather(flatten_tensor, cum_adv_index, 0)
256+
gather_out = gather(ctx, target, source_ir, name, flatten_tensor, 0, cum_adv_index)
257257
_LOGGER.debug(f"The shape after cumultative gather is {gather_out.shape}")
258258
_LOGGER.debug(f"The shape for cumulative adv index is {cum_adv_index}")
259259

0 commit comments

Comments
 (0)