feat: a lowering pass to re-compose ops into aten.linear (#2411)
zewenli98 authored Nov 9, 2023
1 parent 7029e91 commit da90d61
Showing 10 changed files with 292 additions and 117 deletions.
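
Summary: this commit adds a lowering pass that matches the permute + addmm sequence aten.linear typically decomposes into and re-composes it back into a single aten.linear node, adds a TensorRT converter and impl for standalone aten.addmm, and removes the now-redundant eager addmm decomposition. A quick eager-mode sanity check of the equivalence the pattern relies on (illustrative only; not part of the commit):

import torch

x = torch.randn(3, 32)   # input
w = torch.randn(64, 32)  # weight
b = torch.randn(64)      # bias

# linear(x, w, b) computes the same result as addmm(b, x, w^T)
via_addmm = torch.ops.aten.addmm.default(b, x, torch.ops.aten.permute.default(w, [1, 0]))
via_linear = torch.ops.aten.linear.default(x, w, b)
assert torch.allclose(via_addmm, via_linear, atol=1e-5)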
28 changes: 28 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1957,3 +1957,31 @@ def aten_ops_argmax(
dim=args_bounds_check(args, 1),
keep_dim=args_bounds_check(args, 2, False),
)


@dynamo_tensorrt_converter(torch.ops.aten.addmm.default) # type: ignore[misc]
@enforce_tensor_types(
{
0: (TRTTensor,),
1: (np.ndarray, torch.Tensor, TRTTensor),
2: (np.ndarray, torch.Tensor, TRTTensor),
}
) # type: ignore[misc]
def aten_ops_addmm(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.addmm.addmm(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
args[2],
beta=kwargs.get("beta", 1),
alpha=kwargs.get("alpha", 1),
)
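
Note: per the enforce_tensor_types decorator above, the accumulator argument args[0] must already be a TRTTensor, while mat1 and mat2 (args[1], args[2]) may also arrive as frozen np.ndarray or torch.Tensor constants (e.g. weights).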
3 changes: 2 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -2,8 +2,9 @@

from . import (
    activation,
-    attention,
+    addmm,
    argmax,
+    attention,
cast,
cat,
condition,
34 changes: 34 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/addmm.py
@@ -0,0 +1,34 @@
from typing import Optional, Union

import numpy as np
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.types import TRTTensor


def addmm(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
mat1: Union[TRTTensor, torch.Tensor, np.ndarray],
mat2: Union[TRTTensor, torch.Tensor, np.ndarray],
*,
beta: Union[float, int],
alpha: Union[float, int],
) -> TRTTensor:
mm = impl.matmul.matrix_multiply(ctx, target, source_ir, f"{name}_mm", mat1, mat2)
if alpha != 1:
mm = impl.elementwise.mul(
            ctx, target, source_ir, f"{name}_mul_alpha", mm, alpha
)
if beta != 1:
input = impl.elementwise.mul(
            ctx, target, source_ir, f"{name}_mul_beta", input, beta
)

return impl.elementwise.add(ctx, target, source_ir, f"{name}_add", input, mm)
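
The impl follows the ATen contract addmm(input, mat1, mat2, *, beta, alpha) = beta * input + alpha * (mat1 @ mat2), skipping the scalar multiplies when a coefficient equals 1 so no redundant TRT layers are emitted. A minimal eager-mode check of that identity (illustrative only; not part of the commit):

import torch

inp, m1, m2 = torch.randn(2, 2), torch.randn(2, 3), torch.randn(3, 2)
beta, alpha = 3.0, 2.0
expected = beta * inp + alpha * (m1 @ m2)
assert torch.allclose(torch.addmm(inp, m1, m2, beta=beta, alpha=alpha), expected)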
16 changes: 0 additions & 16 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -105,22 +105,6 @@ def alias_replacement(x: torch.Tensor) -> torch.Tensor:
return x


-@register_torch_trt_decomposition(
-    torch.ops.aten.addmm, registry=TORCH_TRT_DECOMPOSITIONS
-)
-def addmm_replacement(
-    input_: torch.Tensor,
-    mat1: torch.Tensor,
-    mat2: torch.Tensor,
-    *,
-    beta: int = 1,
-    alpha: int = 1,
-) -> torch.Tensor:
-    return torch.add(
-        torch.mul(input_, beta), torch.mul(torch.matmul(mat1, mat2), alpha)
-    )


@register_torch_trt_decomposition(
torch.ops.aten.reciprocal.default, registry=TORCH_TRT_DECOMPOSITIONS
)
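
With the dedicated aten.addmm converter added above, this eager decomposition becomes redundant; more importantly, keeping it would break addmm into mul/matmul/add nodes before the new lower_linear pass could ever match the permute + addmm pattern, so it is removed.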
@@ -6,6 +6,7 @@
from .constant_folding import constant_fold
from .fuse_prims_broadcast import fuse_prims_broadcast
from .lower_efficient_attention import lower_efficient_attention
+from .lower_linear import lower_linear
from .pass_manager import DynamoPassManager
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .repair_input_as_output import repair_input_as_output
@@ -17,6 +18,7 @@
constant_fold,
repair_input_as_output,
lower_efficient_attention,
+    lower_linear,
fuse_prims_broadcast,
replace_max_pool_with_indices,
]
@@ -5,7 +5,6 @@
import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
-    get_tensor_placeholders,
)

logger = logging.getLogger(__name__)
@@ -36,34 +35,13 @@ def efficient_attention_replacement() -> (
):
"""Constructs the original and replacement functions for efficient attention"""

-    # Empty boilerplate function taking in three Tensors and returning one
-    def boilerplate(
-        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-    ) -> torch.Tensor:
-        ...
-
-    # Trace boilerplate function and extract placeholder and output nodes
-    orig = torch.fx.symbolic_trace(boilerplate)
-    q, k, v = get_tensor_placeholders(orig)
-    output = [node for node in orig.graph.nodes if node.op == "output"][0]
-
-    # Graph types to replace are those which use the _scaled_dot_product_efficient_attention
-    # function and extract only the first element
-    with orig.graph.inserting_before(output):
-        att = orig.graph.call_function(
-            torch.ops.aten._scaled_dot_product_efficient_attention.default,
-            args=(q, k, v, None, False),
+    # Original graph
+    def orig(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
+            q, k, v, None, False
        )
-        out = orig.graph.call_function(
-            operator.getitem,
-            args=(att, 0),
-        )
-
-    # Assign the output of the graph to be the single getitem output
-    output.args = (out,)
-
-    orig.graph.lint()
-    orig.recompile()
+        out = operator.getitem(outputs, 0)
+        return out

# Replacement graph consists of the functional version of scaled_dot_product_attention
def replacement(
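
Note: torch.fx.subgraph_rewriter.replace_pattern symbolically traces the pattern callable itself, so the pattern can be expressed as a plain function. That makes the hand-built symbolic_trace/call_function boilerplate (and the get_tensor_placeholders import removed above) unnecessary.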
47 changes: 47 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py
@@ -0,0 +1,47 @@
import logging
from typing import Callable, Sequence, Tuple

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


def lower_linear(
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
"""Replace aten.linear with an equivalent implementation which can be easily converted to TRT"""
orig, replacement = linear_replacement()

if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after lowering linear:\n{gm.graph}")

return gm


def linear_replacement() -> (
Tuple[
torch.fx.GraphModule,
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]
):
"""Constructs the original and replacement functions for linear"""

# Original graph
def orig(
input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
W_T = torch.ops.aten.permute.default(weight, [1, 0])
out = torch.ops.aten.addmm.default(bias, input, W_T)
return out

# Replacement graph
def replacement(
input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
return torch.ops.aten.linear.default(input, weight, bias)

return orig, replacement
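
A minimal sketch of the pass in action (hypothetical usage; the module name is illustrative, the import path is the one introduced by this commit):

import torch
from torch_tensorrt.dynamo.lowering.passes.lower_linear import lower_linear

class DecomposedLinear(torch.nn.Module):
    def forward(self, input, weight, bias):
        w_t = torch.ops.aten.permute.default(weight, [1, 0])
        return torch.ops.aten.addmm.default(bias, input, w_t)

gm = torch.fx.symbolic_trace(DecomposedLinear())
gm = lower_linear(gm, sample_inputs=[])  # matching is structural; the inputs are not used by the pass
print(gm.graph)  # the permute/addmm pair is now a single aten.linear node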
65 changes: 65 additions & 0 deletions tests/py/dynamo/conversion/test_addmm_aten.py
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestAddmmConverter(DispatchTestCase):
@parameterized.expand(
[
((2, 2), (2, 3), (3, 2)),
((4, 6), (4, 5), (5, 6)),
((2, 1), (2, 3), (3, 1)),
((4, 1), (4, 1), (1, 1)),
((1, 2), (1, 3), (3, 2)),
]
)
def test_addmm(self, input_shape, mat1_shape, mat2_shape):
class Addmm(nn.Module):
def forward(self, input, mat1, mat2):
return torch.ops.aten.addmm.default(input, mat1, mat2)

inputs = [
torch.randn(input_shape),
torch.randn(mat1_shape),
torch.randn(mat2_shape),
]

self.run_test(
Addmm(),
inputs,
)

@parameterized.expand(
[
((2, 2), (2, 3), (3, 2), 1.0, 1.0),
((4, 6), (4, 5), (5, 6), 1.2, 0.8),
((2, 1), (2, 3), (3, 1), 3, 2),
((4, 1), (4, 1), (1, 1), 1, 1),
((1, 2), (1, 3), (3, 2), 2, 1.0),
((1, 2), (1, 3), (3, 2), 1, 2.0),
]
)
def test_addmm_scale(self, input_shape, mat1_shape, mat2_shape, beta, alpha):
class Addmm(nn.Module):
def forward(self, input, mat1, mat2):
return torch.ops.aten.addmm.default(
input, mat1, mat2, beta=beta, alpha=alpha
)

inputs = [
torch.randn(input_shape),
torch.randn(mat1_shape),
torch.randn(mat2_shape),
]

self.run_test(
Addmm(),
inputs,
)


if __name__ == "__main__":
run_tests()
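
These cases exercise the new converter directly: DispatchTestCase.run_test (from the local harness module) converts the traced module with the registered aten converters and compares TensorRT results against eager PyTorch within tolerance.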
108 changes: 108 additions & 0 deletions tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -267,5 +267,113 @@ def forward(self, q, k, v):
torch._dynamo.reset()


class TestLowerLinear(TestCase):
def test_lower_linear(self):
class Linear(torch.nn.Module):
def forward(self, input, weight, bias):
out = torch.ops.aten.linear.default(input, weight, bias)
return out

inputs = [
torch.rand((3, 32)).cuda(),
torch.rand((64, 32)).cuda(),
torch.rand((64,)).cuda(),
]

fx_graph = torch.fx.symbolic_trace(Linear())
expected_ops = {torch.ops.aten.linear.default}
unexpected_ops = {
torch.ops.aten.permute.default,
torch.ops.aten.addmm.default,
}

unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)

        self.assertEqual(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

        self.assertEqual(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)
torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
pass_through_build_failures=True,
)
optimized_model_results = torch.cat(
[tensor.detach().cpu() for tensor in optimized_model(*inputs)]
)
torch_model_results = torch.cat(
[tensor.detach().cpu() for tensor in fx_graph(*inputs)]
)

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
msg=f"Linear TRT outputs don't match with the original model.",
)
torch._dynamo.reset()

def test_lower_linear_batch(self):
class Linear(torch.nn.Module):
def forward(self, input, weight, bias):
out = torch.ops.aten.linear.default(input, weight, bias)
return out

inputs = [
torch.rand((2, 2, 32)).cuda(),
torch.rand((64, 32)).cuda(),
torch.rand((64,)).cuda(),
]

fx_graph = torch.fx.symbolic_trace(Linear())

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
pass_through_build_failures=True,
)
optimized_model_results = torch.cat(
[tensor.detach().cpu() for tensor in optimized_model(*inputs)]
)
torch_model_results = torch.cat(
[tensor.detach().cpu() for tensor in fx_graph(*inputs)]
)

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
msg=f"Linear TRT outputs don't match with the original model.",
)
torch._dynamo.reset()


if __name__ == "__main__":
run_tests()