Skip to content

Commit

Permalink
gather changes
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Nov 13, 2023
1 parent 7011809 commit 45668fd
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 10 deletions.
25 changes: 25 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,31 @@ def aten_ops_native_group_norm(
)


@dynamo_tensorrt_converter(torch.ops.aten.gather)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
) # type: ignore[misc]
def aten_ops_gather(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.select.gather(
ctx,
target,
SourceIR.ATEN,
name,
input=args[0],
dim=args[1],
index=args[2],
sparse_grad = args_bounds_check(args, 4, False),
)


@dynamo_tensorrt_converter(torch.ops.aten.group_norm.default)
@dynamo_tensorrt_converter(torch.ops.aten.group_norm)
@enforce_tensor_types(
Expand Down
27 changes: 17 additions & 10 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,26 @@ def select(
indices_tensor = ctx.net.add_constant(
index_value.shape, to_numpy(index_value)
).get_output(0)
layer = ctx.net.add_gather(input, indices_tensor, dim)
out = layer.get_output(0)
out = gather(input, indices_tensor, dim)
if len(out.shape) != 1:
layer = ctx.net.add_shuffle(out)
return layer.get_output(0)


def gather(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dim: int,
index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
) -> TRTTensor:
gather_layer = ctx.net.add_gather(input, index, dim)
set_layer_name(gather_layer, target, name + "_gather", source_ir)
return gather_layer.get_output(0)


def index(
ctx: ConversionContext,
target: Target,
Expand Down Expand Up @@ -127,9 +140,7 @@ def index(
)
index = adv_indx_indices[0]
_LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
gather_layer = ctx.net.add_gather(input, indices_tensor, index)
set_layer_name(gather_layer, target, name + "_index_gather", source_ir)
return gather_layer.get_output(0)
return gather(input, index, indices_tensor)
else:
input_shape = input.shape
_LOGGER.debug(f"The input shape is {input.shape}")
Expand Down Expand Up @@ -242,11 +253,7 @@ def index(
dim_tensor_list[adv_indx_indices[i]],
)

gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0)
set_layer_name(
gather_layer_element, target, name + "_index_gather_element", source_ir
)
gather_out = gather_layer_element.get_output(0)
gather_out = gather(flatten_tensor, cum_adv_index, 0)
_LOGGER.debug(f"The shape after cumultative gather is {gather_out.shape}")
_LOGGER.debug(f"The shape for cumulative adv index is {cum_adv_index}")

Expand Down

0 comments on commit 45668fd

Please sign in to comment.