-
Notifications
You must be signed in to change notification settings - Fork 351
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: a lowering pass to re-compose ops into aten.linear (#2411)
- Loading branch information
Showing
10 changed files
with
292 additions
and
117 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,8 +2,9 @@ | |
|
||
from . import ( | ||
activation, | ||
attention, | ||
addmm, | ||
argmax, | ||
attention, | ||
cast, | ||
cat, | ||
condition, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from typing import Optional, Union | ||
|
||
import numpy as np | ||
import torch | ||
from torch.fx.node import Target | ||
from torch_tensorrt.dynamo._SourceIR import SourceIR | ||
from torch_tensorrt.dynamo.conversion import impl | ||
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext | ||
from torch_tensorrt.fx.types import TRTTensor | ||
|
||
|
||
def addmm( | ||
ctx: ConversionContext, | ||
target: Target, | ||
source_ir: Optional[SourceIR], | ||
name: str, | ||
input: TRTTensor, | ||
mat1: Union[TRTTensor, torch.Tensor, np.ndarray], | ||
mat2: Union[TRTTensor, torch.Tensor, np.ndarray], | ||
*, | ||
beta: Union[float, int], | ||
alpha: Union[float, int], | ||
) -> TRTTensor: | ||
mm = impl.matmul.matrix_multiply(ctx, target, source_ir, f"{name}_mm", mat1, mat2) | ||
if alpha != 1: | ||
mm = impl.elementwise.mul( | ||
ctx, target, SourceIR.ATEN, f"{name}_mul_alpha", mm, alpha | ||
) | ||
if beta != 1: | ||
input = impl.elementwise.mul( | ||
ctx, target, SourceIR.ATEN, f"{name}_mul_beta", input, beta | ||
) | ||
|
||
return impl.elementwise.add(ctx, target, source_ir, f"{name}_add", input, mm) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import logging | ||
from typing import Callable, Sequence, Tuple | ||
|
||
import torch | ||
from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( | ||
clean_up_graph_after_modifications, | ||
) | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def lower_linear( | ||
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor] | ||
) -> torch.fx.GraphModule: | ||
"""Replace aten.linear with an equivalent implementation which can be easily converted to TRT""" | ||
orig, replacement = linear_replacement() | ||
|
||
if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement): | ||
gm = clean_up_graph_after_modifications(gm) | ||
logger.debug(f"Graph after lowering linear:\n{gm.graph}") | ||
|
||
return gm | ||
|
||
|
||
def linear_replacement() -> ( | ||
Tuple[ | ||
torch.fx.GraphModule, | ||
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], | ||
] | ||
): | ||
"""Constructs the original and replacement functions for linear""" | ||
|
||
# Original graph | ||
def orig( | ||
input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | ||
) -> torch.Tensor: | ||
W_T = torch.ops.aten.permute.default(weight, [1, 0]) | ||
out = torch.ops.aten.addmm.default(bias, input, W_T) | ||
return out | ||
|
||
# Replacement graph | ||
def replacement( | ||
input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | ||
) -> torch.Tensor: | ||
return torch.ops.aten.linear.default(input, weight, bias) | ||
|
||
return orig, replacement |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import torch | ||
import torch.nn as nn | ||
from parameterized import parameterized | ||
from torch.testing._internal.common_utils import run_tests | ||
|
||
from .harness import DispatchTestCase | ||
|
||
|
||
class TestAddmmConverter(DispatchTestCase): | ||
@parameterized.expand( | ||
[ | ||
((2, 2), (2, 3), (3, 2)), | ||
((4, 6), (4, 5), (5, 6)), | ||
((2, 1), (2, 3), (3, 1)), | ||
((4, 1), (4, 1), (1, 1)), | ||
((1, 2), (1, 3), (3, 2)), | ||
] | ||
) | ||
def test_addmm(self, input_shape, mat1_shape, mat2_shape): | ||
class Addmm(nn.Module): | ||
def forward(self, input, mat1, mat2): | ||
return torch.ops.aten.addmm.default(input, mat1, mat2) | ||
|
||
inputs = [ | ||
torch.randn(input_shape), | ||
torch.randn(mat1_shape), | ||
torch.randn(mat2_shape), | ||
] | ||
|
||
self.run_test( | ||
Addmm(), | ||
inputs, | ||
) | ||
|
||
@parameterized.expand( | ||
[ | ||
((2, 2), (2, 3), (3, 2), 1.0, 1.0), | ||
((4, 6), (4, 5), (5, 6), 1.2, 0.8), | ||
((2, 1), (2, 3), (3, 1), 3, 2), | ||
((4, 1), (4, 1), (1, 1), 1, 1), | ||
((1, 2), (1, 3), (3, 2), 2, 1.0), | ||
((1, 2), (1, 3), (3, 2), 1, 2.0), | ||
] | ||
) | ||
def test_addmm_scale(self, input_shape, mat1_shape, mat2_shape, beta, alpha): | ||
class Addmm(nn.Module): | ||
def forward(self, input, mat1, mat2): | ||
return torch.ops.aten.addmm.default( | ||
input, mat1, mat2, beta=beta, alpha=alpha | ||
) | ||
|
||
inputs = [ | ||
torch.randn(input_shape), | ||
torch.randn(mat1_shape), | ||
torch.randn(mat2_shape), | ||
] | ||
|
||
self.run_test( | ||
Addmm(), | ||
inputs, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_tests() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.