2023-11-09 nightly release (da90d61)
pytorchbot committed Nov 9, 2023
1 parent 20667b9 commit 6455e2b
Showing 13 changed files with 385 additions and 160 deletions.
28 changes: 28 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1957,3 +1957,31 @@ def aten_ops_argmax(
dim=args_bounds_check(args, 1),
keep_dim=args_bounds_check(args, 2, False),
)


@dynamo_tensorrt_converter(torch.ops.aten.addmm.default) # type: ignore[misc]
@enforce_tensor_types(
{
0: (TRTTensor,),
1: (np.ndarray, torch.Tensor, TRTTensor),
2: (np.ndarray, torch.Tensor, TRTTensor),
}
) # type: ignore[misc]
def aten_ops_addmm(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.addmm.addmm(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
args[2],
beta=kwargs.get("beta", 1),
alpha=kwargs.get("alpha", 1),
)
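For reference, the converter above unpacks the standard aten.addmm schema: args[0] is the input scaled by beta, args[1] and args[2] are the matrices, and beta/alpha arrive as keyword arguments. A quick eager-mode sketch of those semantics (illustrative shapes, not part of the commit):

import torch

inp, mat1, mat2 = torch.randn(2, 3), torch.randn(2, 4), torch.randn(4, 3)

# aten.addmm computes beta * input + alpha * (mat1 @ mat2)
out = torch.ops.aten.addmm.default(inp, mat1, mat2, beta=0.5, alpha=2.0)
assert torch.allclose(out, 0.5 * inp + 2.0 * (mat1 @ mat2))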
3 changes: 1 addition & 2 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -181,8 +181,7 @@ def cast_int_int_div_trt_tensor(


def broadcastable(
a: TRTTensor,
b: TRTTensor,
a: Union[TRTTensor, np.ndarray], b: Union[TRTTensor, np.ndarray]
) -> bool:
"Check if two tensors are broadcastable according to torch rules"
a_shape = tuple(a.shape)
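The widened signature relies on the torch broadcasting rule: shapes are compared from the trailing dimension backwards, and a dimension pair is compatible when the sizes are equal or either is 1. A rough standalone sketch of that rule (not the repository's exact implementation):

def is_broadcastable(shape_a, shape_b):
    # walk both shapes right-to-left; sizes must match or one must be 1
    for x, y in zip(reversed(shape_a), reversed(shape_b)):
        if x != y and x != 1 and y != 1:
            return False
    return True

assert is_broadcastable((4, 1, 3), (2, 3))
assert not is_broadcastable((4, 2, 3), (3, 3))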
3 changes: 2 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -2,8 +2,9 @@

from . import (
activation,
attention,
addmm,
argmax,
attention,
cast,
cat,
condition,
34 changes: 34 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/addmm.py
@@ -0,0 +1,34 @@
from typing import Optional, Union

import numpy as np
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.types import TRTTensor


def addmm(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
mat1: Union[TRTTensor, torch.Tensor, np.ndarray],
mat2: Union[TRTTensor, torch.Tensor, np.ndarray],
*,
beta: Union[float, int],
alpha: Union[float, int],
) -> TRTTensor:
mm = impl.matmul.matrix_multiply(ctx, target, source_ir, f"{name}_mm", mat1, mat2)
if alpha != 1:
mm = impl.elementwise.mul(
ctx, target, SourceIR.ATEN, f"{name}_mul_alpha", mm, alpha
)
if beta != 1:
input = impl.elementwise.mul(
ctx, target, SourceIR.ATEN, f"{name}_mul_beta", input, beta
)

return impl.elementwise.add(ctx, target, source_ir, f"{name}_add", input, mm)
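The implementation composes a matmul with optional element-wise scales and a final add, skipping the multiplications when alpha or beta is 1. A plain-torch mirror of that layer sequence, useful for sanity-checking the converter (hypothetical shapes):

import torch

def addmm_reference(inp, mat1, mat2, *, beta=1.0, alpha=1.0):
    # mirrors the layer sequence above: matmul, optional scales, then add
    mm = torch.matmul(mat1, mat2)
    if alpha != 1:
        mm = mm * alpha
    if beta != 1:
        inp = inp * beta
    return inp + mm

inp, m1, m2 = torch.randn(3, 5), torch.randn(3, 4), torch.randn(4, 5)
ref = torch.addmm(inp, m1, m2, beta=2.0, alpha=0.5)
assert torch.allclose(addmm_reference(inp, m1, m2, beta=2.0, alpha=0.5), ref)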
103 changes: 64 additions & 39 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -3,6 +3,7 @@

import numpy as np
import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -80,23 +81,34 @@ def index(
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
index: Union[TRTTensor, Sequence[TRTTensor]],
index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
) -> TRTTensor:
adv_indx_indices = []
tensor_indices = []
# _LOGGER.debug(f"The index shape is {index.shape}")
# check if the input is dynamic
dynamic_shape = has_dynamic_shape(input.shape)

# is_numpy is a flag specifying whether all the indices are numpy arrays or torch.Tensors;
# if any index is not, the flag is set to False
_LOGGER.debug(
f"Determining whether aten.index constant-index optimization can be invoked"
)
is_numpy = all(
isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
)
# here we need to check that all the indices are broadcastable;
# if not, they need to be broadcast
last_index = None
for i, ind in enumerate(index):
if ind is not None:
_LOGGER.debug(f"Shape of {i} index is {ind.shape}")
adv_indx_indices.append(i)
# torch.nn.parameter.Parameter=> torch.Tensor
ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
# torch.nn.parameter.Parameter => numpy array;
# numpy arrays are kept as numpy, all other cases become TRTTensor
if is_numpy:
ind = to_numpy(ind)
else:
ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
if last_index is not None:
assert broadcastable(
ind, last_index
@@ -110,8 +122,9 @@ def index(
set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
return identity_layer.get_output(0)
elif len(tensor_indices) == 1:
# This case works
indices_tensor = tensor_indices[0]
indices_tensor = get_trt_tensor(
ctx, tensor_indices[0], name + f"_parameter_to_fp32_tensor"
)
index = adv_indx_indices[0]
_LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
gather_layer = ctx.net.add_gather(input, indices_tensor, index)
@@ -150,6 +163,7 @@ def index(
if i not in adv_indx_indices:
new_order.append(i)
_LOGGER.debug(f"The new transpose order is {new_order}")

transpose_layer.second_transpose = tuple(new_order)
set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
transpose_tensor = transpose_layer.get_output(0)
@@ -175,47 +189,58 @@
concat_tensor = concat_tensor_layer.get_output(0)

reshape_layer = ctx.net.add_shuffle(transpose_tensor)
# check this
reshape_layer.set_input(1, concat_tensor)
flatten_tensor = reshape_layer.get_output(0)

_LOGGER.debug(f"The flatten tensor shape is {flatten_tensor.shape}")

    # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m x_j), where ind_i is
    # the advanced index for dimension i and x_j is the size of dimension j of x
    if is_numpy:
        multiplier = input_shape[adv_indx_indices[adv_indx_count - 1]]
        cum_adv_index = tensor_indices[adv_indx_count - 1]
        for i in range(adv_indx_count - 2, -1, -1):
            adv_index = multiplier * tensor_indices[i]
            cum_adv_index = cum_adv_index + adv_index
            multiplier = multiplier * input_shape[adv_indx_indices[i]]
        cum_adv_index = get_trt_tensor(
            ctx, cum_adv_index, name + f"_index_sum_intermediate"
        )
    else:
        multiplier = get_trt_tensor(
            ctx,
            dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
            name + "_dim_last",
        )
        cum_adv_index = tensor_indices[adv_indx_count - 1]
        for i in range(adv_indx_count - 2, -1, -1):
            adv_index = convert_binary_elementwise(
                ctx,
                target,
                source_ir,
                name + f"_index_intermediate_{i}",
                trt.ElementWiseOperation.PROD,
                multiplier,
                tensor_indices[i],
            )
            cum_adv_index = convert_binary_elementwise(
                ctx,
                target,
                source_ir,
                name + f"_index_sum_intermediate_{i}",
                trt.ElementWiseOperation.SUM,
                cum_adv_index,
                adv_index,
            )
            multiplier = convert_binary_elementwise(
                ctx,
                target,
                source_ir,
                name + f"_index_intermediate_xj_{i}",
                trt.ElementWiseOperation.PROD,
                multiplier,
                dim_tensor_list[adv_indx_indices[i]],
            )

gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0)
set_layer_name(
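When every index is a host-side constant (the is_numpy path), the cumulative gather index can be computed with plain numpy arithmetic instead of TRT elementwise layers. A small numpy sketch of that row-major accumulation (toy shapes, not the repository's code):

import numpy as np

x = np.arange(24).reshape(2, 3, 4)
ind0 = np.array([0, 1])  # advanced index along dim 0
ind1 = np.array([2, 0])  # advanced index along dim 1

# cum_adv_index = ind1 + ind0 * x.shape[1], mirroring the accumulation loop
multiplier = x.shape[1]
cum_adv_index = ind1 + multiplier * ind0

flat = x.reshape(-1, 4)          # advanced dims collapsed into one gather axis
gathered = flat[cum_adv_index]   # stands in for the TRT gather layer
assert (gathered == x[ind0, ind1]).all()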
16 changes: 0 additions & 16 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -105,22 +105,6 @@ def alias_replacement(x: torch.Tensor) -> torch.Tensor:
return x


@register_torch_trt_decomposition(
torch.ops.aten.addmm, registry=TORCH_TRT_DECOMPOSITIONS
)
def addmm_replacement(
input_: torch.Tensor,
mat1: torch.Tensor,
mat2: torch.Tensor,
*,
beta: int = 1,
alpha: int = 1,
) -> torch.Tensor:
return torch.add(
torch.mul(input_, beta), torch.mul(torch.matmul(mat1, mat2), alpha)
)


@register_torch_trt_decomposition(
torch.ops.aten.reciprocal.default, registry=TORCH_TRT_DECOMPOSITIONS
)
py/torch_tensorrt/dynamo/lowering/passes/__init__.py
@@ -6,6 +6,7 @@
from .constant_folding import constant_fold
from .fuse_prims_broadcast import fuse_prims_broadcast
from .lower_efficient_attention import lower_efficient_attention
from .lower_linear import lower_linear
from .pass_manager import DynamoPassManager
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .repair_input_as_output import repair_input_as_output
@@ -17,6 +18,7 @@
constant_fold,
repair_input_as_output,
lower_efficient_attention,
lower_linear,
fuse_prims_broadcast,
replace_max_pool_with_indices,
]
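Each entry in this list is an FX lowering pass with the signature (gm, sample_inputs) -> gm, and the pass manager applies them in list order. A minimal sketch of that control flow (illustrative; not the exact DynamoPassManager API):

import torch

def apply_lowering_passes(gm: torch.fx.GraphModule, sample_inputs, passes):
    # thread the graph module through each registered pass, in order
    for lowering_pass in passes:
        gm = lowering_pass(gm, sample_inputs)
    return gm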
py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py
@@ -5,7 +5,6 @@
import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
get_tensor_placeholders,
)

logger = logging.getLogger(__name__)
@@ -36,34 +35,13 @@ def efficient_attention_replacement() -> (
):
"""Constructs the original and replacement functions for efficient attention"""

    # Empty boilerplate function taking in three Tensors and returning one
    def boilerplate(
        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> torch.Tensor:
        ...

    # Trace boilerplate function and extract placeholder and output nodes
    orig = torch.fx.symbolic_trace(boilerplate)
    q, k, v = get_tensor_placeholders(orig)
    output = [node for node in orig.graph.nodes if node.op == "output"][0]

    # Graph types to replace are those which use the _scaled_dot_product_efficient_attention
    # function and extract only the first element
    with orig.graph.inserting_before(output):
        att = orig.graph.call_function(
            torch.ops.aten._scaled_dot_product_efficient_attention.default,
            args=(q, k, v, None, False),
        )
        out = orig.graph.call_function(
            operator.getitem,
            args=(att, 0),
        )

    # Assign the output of the graph to be the single getitem output
    output.args = (out,)

    orig.graph.lint()
    orig.recompile()

    # Original graph
    def orig(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
            q, k, v, None, False
        )
        out = operator.getitem(outputs, 0)
        return out

# Replacement graph consists of the functional version of scaled_dot_product_attention
def replacement(
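The simplification works because aten._scaled_dot_product_efficient_attention returns a tuple whose first element is the attention output, so the pattern function can call the op and take item 0 directly instead of building the graph by hand. A sanity check of that contract (assumes a CUDA device, since the efficient-attention kernel is GPU-only):

import torch

q = torch.randn(1, 4, 8, 16, device="cuda")
outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
    q, q, q, None, False
)
out = outputs[0]  # first tuple element is the attention output
ref = torch.nn.functional.scaled_dot_product_attention(q, q, q)
assert torch.allclose(out, ref, atol=1e-5)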
47 changes: 47 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py
@@ -0,0 +1,47 @@
import logging
from typing import Callable, Sequence, Tuple

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


def lower_linear(
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
"""Replace aten.linear with an equivalent implementation which can be easily converted to TRT"""
orig, replacement = linear_replacement()

if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after lowering linear:\n{gm.graph}")

return gm


def linear_replacement() -> (
Tuple[
torch.fx.GraphModule,
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]
):
"""Constructs the original and replacement functions for linear"""

# Original graph
def orig(
input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
W_T = torch.ops.aten.permute.default(weight, [1, 0])
out = torch.ops.aten.addmm.default(bias, input, W_T)
return out

# Replacement graph
def replacement(
input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
return torch.ops.aten.linear.default(input, weight, bias)

return orig, replacement
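The rewrite is safe because the two pattern functions are numerically identical: for a 2-D input, linear(input, weight, bias) equals addmm(bias, input, weight.T). A quick eager-mode check of that equivalence (illustrative shapes):

import torch

inp, weight, bias = torch.randn(4, 8), torch.randn(6, 8), torch.randn(6)

w_t = torch.ops.aten.permute.default(weight, [1, 0])
via_addmm = torch.ops.aten.addmm.default(bias, inp, w_t)
via_linear = torch.ops.aten.linear.default(inp, weight, bias)
assert torch.allclose(via_addmm, via_linear, atol=1e-6)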