# Gather Implementation #2457
**File: `py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py`**
```diff
@@ -183,6 +183,31 @@ def aten_ops_native_group_norm(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.gather.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)  # type: ignore[misc]
+def aten_ops_gather(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.select.gather(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        dim=args[1],
+        index=args[2],
+        sparse_grad=args_bounds_check(args, 3, False),
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.group_norm.default)
 @dynamo_tensorrt_converter(torch.ops.aten.group_norm)
 @enforce_tensor_types(
```

> **Review comment** (on `sparse_grad=args_bounds_check(args, 3, False)`): Based on the schema here, `sparse_grad` can be removed if it is never used.
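For context, the eager-mode schema being converted here is `aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor`. A minimal PyTorch example of the semantics the converter has to reproduce (illustration only, not part of the PR):

```python
import torch

x = torch.tensor([[1, 2], [3, 4]])
idx = torch.tensor([[0, 0], [1, 0]])  # eager PyTorch requires int64 indices here

# For dim=1: out[i][j] = x[i][idx[i][j]]
out = torch.gather(x, 1, idx)
print(out)  # tensor([[1, 1],
            #         [4, 3]])
```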
**File: `py/torch_tensorrt/dynamo/conversion/impl/select.py`**
```diff
@@ -9,6 +9,7 @@
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     broadcastable,
+    cast_trt_tensor,
     get_positive_dim,
     get_trt_tensor,
     to_numpy,
```
```diff
@@ -68,11 +69,31 @@ def select(
         indices_tensor = ctx.net.add_constant(
             index_value.shape, to_numpy(index_value)
         ).get_output(0)
-    layer = ctx.net.add_gather(input, indices_tensor, dim)
-    out = layer.get_output(0)
+    return gather(ctx, target, source_ir, name, input, dim, indices_tensor)
+
+
+def gather(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: int,
+    index: Union[TRTTensor, np.ndarray, torch.Tensor],
+    sparse_grad: bool = False,
+) -> TRTTensor:
+    if not isinstance(index, TRTTensor):
+        index = get_trt_tensor(ctx, index, name + f"_parameter_to_fp32_tensor")
+    # This is for the case where torch.ops.aten.gather requires torch.int64.
+    # However, TRTInterpreter complains that torch.int64 is not a supported type,
+    # so the cast below does not help:
+    # index = cast_trt_tensor(ctx, input, trt.int32, name, target, source_ir)
+    gather_layer = ctx.net.add_gather(input, index, dim)
+    set_layer_name(gather_layer, target, name + "_gather", source_ir)
+    out = gather_layer.get_output(0)
     if len(out.shape) != 1:
-        layer = ctx.net.add_shuffle(out)
-        return layer.get_output(0)
+        gather_layer = ctx.net.add_shuffle(out)
+        return gather_layer.get_output(0)


 def index(
```

> **Review comment** (on `sparse_grad: bool = False`): See above, can remove if never used.

> **Review comment** (on lines +86 to +89, the commented-out cast): Does this issue still occur in the test cases?
>
> **Reply:** Yes, it does. `aten.scatter` has similar use cases, so I am working on that. The casting of nodes in the TRT test infrastructure can be used here to get around it. This is the PR: #2664.
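The int64 limitation discussed above stems from TensorRT (at the time of this PR) not accepting int64 tensors, while `aten.gather` indices are int64 in PyTorch. A minimal sketch of the usual workaround — downcasting indices to int32 before they enter the network — where `downcast_gather_indices` is a hypothetical helper, not part of this PR:

```python
import numpy as np
import torch

def downcast_gather_indices(index: torch.Tensor) -> np.ndarray:
    """Hypothetical helper: convert int64 gather indices to int32 before
    building a TensorRT constant, assuming all values fit in int32 range."""
    arr = index.detach().cpu().numpy()
    if arr.dtype == np.int64:
        assert arr.max(initial=0) <= np.iinfo(np.int32).max
        arr = arr.astype(np.int32)
    return arr

idx = torch.tensor([[0, 0], [1, 0]])       # int64 by default
print(downcast_gather_indices(idx).dtype)  # int32
```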
```diff
@@ -127,9 +148,7 @@ def index(
             )
             index = adv_indx_indices[0]
             _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
-            gather_layer = ctx.net.add_gather(input, indices_tensor, index)
-            set_layer_name(gather_layer, target, name + "_index_gather", source_ir)
-            return gather_layer.get_output(0)
+            return gather(ctx, target, source_ir, name, input, index, indices_tensor)
         else:
             input_shape = input.shape
             _LOGGER.debug(f"The input shape is {input.shape}")
```
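For the single-advanced-index path above, TensorRT's `IGatherLayer` in its default mode selects whole slices along one axis, which matches `torch.index_select`. A small eager-mode illustration (not part of the PR):

```python
import torch

x = torch.arange(12.0).reshape(3, 4)
idx = torch.tensor([2, 0])

# A single advanced index selects rows, i.e. an index_select along dim 0,
# which is what the default-mode gather layer computes.
assert torch.equal(x[idx], torch.index_select(x, 0, idx))
```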
```diff
@@ -242,11 +261,9 @@ def index(
                     dim_tensor_list[adv_indx_indices[i]],
                 )

-            gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0)
-            set_layer_name(
-                gather_layer_element, target, name + "_index_gather_element", source_ir
+            gather_out = gather(
+                ctx, target, source_ir, name, flatten_tensor, 0, cum_adv_index
             )
-            gather_out = gather_layer_element.get_output(0)
             _LOGGER.debug(f"The shape after cumultative gather is {gather_out.shape}")
             _LOGGER.debug(f"The shape for cumulative adv index is {cum_adv_index}")
```
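The hunk above routes the cumulative-index gather through the new `gather` helper. The underlying flatten-and-gather trick, sketched in eager PyTorch (illustration only; variable names mirror the converter's):

```python
import torch

# Advanced indices are linearized into one cumulative index over the
# flattened tensor, then resolved with a single gather along dim 0.
x = torch.arange(12).reshape(3, 4)
rows = torch.tensor([0, 2])
cols = torch.tensor([1, 3])

cum_adv_index = rows * x.shape[1] + cols   # linearized offsets: [1, 11]
flatten_tensor = x.reshape(-1)
out = torch.gather(flatten_tensor, 0, cum_adv_index)

assert torch.equal(out, x[rows, cols])     # tensor([ 1, 11])
```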
> **Review comment** (on `)  # type: ignore[misc]` in `aten_ops_gather`): The `# type: ignore` can be removed.