
Gather Implementation #2457

Closed
Wants to merge 11 commits.
25 changes: 25 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -183,6 +183,31 @@ def aten_ops_native_group_norm(
    )


@dynamo_tensorrt_converter(torch.ops.aten.gather.default)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)  # type: ignore[misc]
Collaborator: The # type: ignore can be removed.

def aten_ops_gather(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.select.gather(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        input=args[0],
        dim=args[1],
        index=args[2],
        sparse_grad=args_bounds_check(args, 3, False),
Collaborator: Based on the schema here, sparse_grad would be in kwargs. Additionally, since it seems to have no effect in the gather converter below, it can be removed/ignored, or a validator can be used to ensure it is False.

    )
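For reference, a minimal sketch of the validator route suggested in the review above. It assumes the converter registry's capability_validator hook and a plain torch.fx.Node inspection; the helper name and exact check are illustrative, not part of this PR:

import torch
from torch.fx.node import Node


def gather_validator(node: Node) -> bool:
    # sparse_grad defaults to False in the aten.gather schema; reject the node
    # (so it falls back to PyTorch execution) if a caller explicitly requests it.
    sparse_grad = node.kwargs.get("sparse_grad", False)
    if len(node.args) > 3:
        sparse_grad = node.args[3]
    return not sparse_grad

The registration could then pass capability_validator=gather_validator to dynamo_tensorrt_converter for torch.ops.aten.gather.default, and sparse_grad could be dropped from the converter arguments entirely.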


@dynamo_tensorrt_converter(torch.ops.aten.group_norm.default)
@dynamo_tensorrt_converter(torch.ops.aten.group_norm)
@enforce_tensor_types(
39 changes: 28 additions & 11 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -9,6 +9,7 @@
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
    broadcastable,
    cast_trt_tensor,
    get_positive_dim,
    get_trt_tensor,
    to_numpy,
@@ -68,11 +69,31 @@ def select(
    indices_tensor = ctx.net.add_constant(
        index_value.shape, to_numpy(index_value)
    ).get_output(0)
-    layer = ctx.net.add_gather(input, indices_tensor, dim)
-    out = layer.get_output(0)
+    return gather(ctx, target, source_ir, name, input, dim, indices_tensor)


def gather(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    dim: int,
    index: Union[TRTTensor, np.ndarray, torch.Tensor],
    sparse_grad: bool = False,
Collaborator: See above, can remove if never used.

) -> TRTTensor:
    if not isinstance(index, TRTTensor):
        index = get_trt_tensor(ctx, index, name + f"_parameter_to_fp32_tensor")
    # This is for the case where torch.ops.aten.gather requires torch.int64.
    # However, TRTInterpreter complains that torch.int64 is not a supported type,
    # so the cast below does not help:
    # index = cast_trt_tensor(ctx, input, trt.int32, name, target, source_ir)
Comment on lines +86 to +89

Collaborator: Does this issue still occur in the test cases?

Collaborator (Author): Yes it does. aten.scatter has similar use cases, so I am working on that. The casting of nodes in the TRT test infrastructure can be used here to get around this. This is the PR: #2664.
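A minimal repro of the dtype constraint being discussed, using only stock PyTorch (nothing TensorRT-specific), to show why the converter receives an int64 index in the first place:

import torch

x = torch.randn(3, 4)
idx = torch.tensor([[0, 1], [2, 0], [3, 3]])  # integer literals default to torch.int64

# aten.gather accepts (and requires) the int64 index tensor:
out = torch.ops.aten.gather(x, 1, idx)
print(out.shape)  # torch.Size([3, 2])

# Narrowing the index up front raises a dtype error from gather itself,
# which is why the int64 tensor survives into the TRT conversion path:
# torch.ops.aten.gather(x, 1, idx.to(torch.int32))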

    gather_layer = ctx.net.add_gather(input, index, dim)
    set_layer_name(gather_layer, target, name + "_gather", source_ir)
    out = gather_layer.get_output(0)
    if len(out.shape) != 1:
-        layer = ctx.net.add_shuffle(out)
-    return layer.get_output(0)
+        gather_layer = ctx.net.add_shuffle(out)
+    return gather_layer.get_output(0)


def index(
@@ -127,9 +148,7 @@ def index(
        )
        index = adv_indx_indices[0]
        _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
-        gather_layer = ctx.net.add_gather(input, indices_tensor, index)
-        set_layer_name(gather_layer, target, name + "_index_gather", source_ir)
-        return gather_layer.get_output(0)
+        return gather(ctx, target, source_ir, name, input, index, indices_tensor)
Collaborator: Line 126/147 needs to be renamed since it uses name + f"_parameter_to_fp32_tensor", which also appears in the gather function. This could cause a duplicate name error in edge cases.
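The simplest fix is to give each get_trt_tensor call site its own literal suffix. If the naming were automated instead, a generic uniquifier along these lines would do; this helper is illustrative only and not taken from the PR:

import itertools
from collections import defaultdict

# A per-prefix counter that yields unique layer/tensor names, so repeated
# call sites cannot collide on the same suffix.
_counters = defaultdict(itertools.count)


def unique_name(prefix: str) -> str:
    return f"{prefix}_{next(_counters[prefix])}"


print(unique_name("gather_index_tensor"))  # gather_index_tensor_0
print(unique_name("gather_index_tensor"))  # gather_index_tensor_1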

    else:
        input_shape = input.shape
        _LOGGER.debug(f"The input shape is {input.shape}")
Expand Down Expand Up @@ -242,11 +261,9 @@ def index(
                dim_tensor_list[adv_indx_indices[i]],
            )

-        gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0)
-        set_layer_name(
-            gather_layer_element, target, name + "_index_gather_element", source_ir
-        )
-        gather_out = gather_layer_element.get_output(0)
+        gather_out = gather(
+            ctx, target, source_ir, name, flatten_tensor, 0, cum_adv_index
+        )
_LOGGER.debug(f"The shape after cumultative gather is {gather_out.shape}")
_LOGGER.debug(f"The shape for cumulative adv index is {cum_adv_index}")
