feat: lowering replace aten.full_like with aten.full (#3077)
Co-authored-by: Dheeraj Peri <[email protected]>
chohk88 and peri044 authored Aug 21, 2024
1 parent 6a38648 commit 7d0f540
Showing 7 changed files with 166 additions and 19 deletions.
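For orientation, the rewrite introduced by this commit is semantically the substitution sketched below: `aten.full_like` takes its shape, dtype, and device implicitly from the input tensor, while `aten.full` receives them explicitly. This is an illustrative snippet, not code from the commit:

```python
import torch

x = torch.randn(2, 3, dtype=torch.float16, device="cpu")

# Before lowering: the graph contains aten.full_like, which implicitly
# inherits x's shape, dtype, and device.
before = torch.full_like(x, 2.0)

# After the new lowering pass: an equivalent aten.full call with the shape,
# dtype, and device read from the input tensor's metadata.
after = torch.full(list(x.shape), 2.0, dtype=x.dtype, device=x.device)

assert torch.equal(before, after)
```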
7 changes: 6 additions & 1 deletion core/runtime/execute_engine.cpp
@@ -294,7 +294,12 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
}
}

auto current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart
auto current_device_id = -1;
if (inputs.size() > 0) {
current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart
} else if (outputs.size() > 0) {
current_device_id = outputs[0].device().index(); // Done this way to avoid a call to cudart
}

compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id);
if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) {
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -3861,4 +3861,5 @@ def aten_ops_full(
name,
shape=args[0],
fill_value=args[1],
dtype=kwargs.get("dtype", None),
)
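As a quick aside (not part of the commit), `aten.full` carries `dtype` as a keyword-only argument, which is what the added `kwargs.get("dtype", None)` forwards to the `full` converter implementation:

```python
import torch

# dtype arrives in kwargs on the aten.full node, so the converter can pick it
# up with kwargs.get("dtype", None) and fall back to the fill value's type.
out = torch.ops.aten.full.default([2, 3], 1, dtype=torch.float32)
print(out.dtype)  # torch.float32
```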
31 changes: 25 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/impl/full.py
@@ -2,7 +2,9 @@

import numpy as np
import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt import _enums
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
@@ -20,23 +22,39 @@ def full(
name: str,
shape: Union[List[int], TRTTensor],
fill_value: Union[int, float, bool],
dtype: Union[torch.dtype, trt.DataType] = None,
) -> TRTTensor:
fill_value_tensor = torch.tensor(fill_value)
if dtype is None:
output_dtype = _enums.dtype._from(fill_value_tensor.dtype)
else:
output_dtype = _enums.dtype._from(dtype)
# in static shape scenario, shape is a list of int
if isinstance(shape, List):
# in static shape scenario, shape is a list of int
if all(isinstance(dim, int) for dim in shape):
return np.full(shape, fill_value)
output_np_dtype = output_dtype.try_to(np.dtype, use_default=True)
return np.full(shape, fill_value, dtype=output_np_dtype)
else:
shape = impl.cat.cat(
ctx, target, source_ir, name + "_concat_shape", shape, 0
)

# in dynamic shape scenario, shape is a shap tensor
# in dynamic shape scenario, shape is a shape tensor
# use IFillLayer to fill the shape tensor with LINSPACE value
layer = ctx.net.add_fill(shape.shape, trt.FillOperation.LINSPACE, shape.dtype)
layer = ctx.net.add_fill(
shape.shape, trt.FillOperation.LINSPACE, trt.DataType.INT32
)
layer.set_input(0, shape)
layer.set_input(1, get_trt_tensor(ctx, 0, name + "_start", min_rank=0))
delta = get_trt_tensor(ctx, 1, name + "_delta")
layer.set_input(
1, get_trt_tensor(ctx, 0, name + "_start", dtype=trt.DataType.INT32, min_rank=0)
)
delta = get_trt_tensor(
ctx,
1,
name + "_delta",
dtype=trt.DataType.INT32,
)
input = []
for _ in range(shape.shape[0]):
input.append(delta)
@@ -46,7 +64,8 @@ def full(

# fill the output tensor with the actual fill_value
output = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", output, 0)
if isinstance(fill_value, (int, float)):
# https://stackoverflow.com/questions/37888620/comparing-boolean-and-int-using-isinstance
if type(fill_value) in (int, float):
if isinstance(fill_value, float):
output = cast_trt_tensor(
ctx, output, trt.float32, name + "_casted", target, source_ir
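The dynamic-shape branch above builds the constant tensor indirectly: an IFillLayer in LINSPACE mode materializes a tensor of the runtime shape (start 0, per-axis delta 1), and the result is then multiplied by zero and shifted to the requested fill value. A rough NumPy sketch of that arithmetic, under the assumption that an elementwise add of `fill_value` follows the `mul` shown in the hunk (illustrative only):

```python
import numpy as np

def full_via_linspace(shape, fill_value):
    # Stand-in for IFillLayer(LINSPACE): produce *some* tensor of the requested
    # (runtime) shape; its actual values do not matter.
    linspace_like = np.arange(np.prod(shape, dtype=np.int64)).reshape(shape)
    # Zero it out, then shift every element to the fill value, mirroring the
    # mul-by-0 in the converter followed by adding fill_value.
    return linspace_like * 0 + fill_value

print(full_via_linspace((2, 3), 7))  # same result as np.full((2, 3), 7)
```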
14 changes: 9 additions & 5 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -168,7 +168,7 @@ def var_decomposition(
@register_torch_trt_decomposition(
torch.ops.aten.empty_permuted.default, registry=TORCH_TRT_DECOMPOSITIONS
)
def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor:
def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor: # type: ignore
empty_size = args[0]
empty_permute = args[1]
perm = [0] * len(empty_size)
@@ -188,7 +188,7 @@ def slice_scatter_decomposition(
start: Optional[int] = None,
end: Optional[int] = None,
step: Optional[int] = None,
):
) -> torch.Tensor:
dim_size = input_tensor.shape[dim]
start = get_positive_dim(start, input_tensor.shape[dim])
if end is None:
@@ -197,6 +197,11 @@ def slice_scatter_decomposition(
if step is None:
step = 1

# Ensure start, end, and step are all integers
assert isinstance(start, int), "start must be an integer"
assert isinstance(end, int), "end must be an integer"
assert isinstance(step, int), "step must be an integer"

src_dim = src_tensor.shape
# step == 0 is not a valid torch case
# also src_dim should be equal to slice dimension
@@ -233,7 +238,7 @@ def select_scatter_decomposition(
@register_torch_trt_decomposition(
torch.ops.aten.empty_strided.default, registry=TORCH_TRT_DECOMPOSITIONS
)
def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:
def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor: # type: ignore
empty_size = args[0]
empty_stride = args[1]
return torch.as_strided(
@@ -256,8 +261,7 @@ def scatter_add_decomposition(
src_shape = list(src_tensor.shape)
src_dim = src_shape[dim]
for i in range(0, src_dim):
to_scatter_tensor = torch.zeros_like(input_tensor)

to_scatter_tensor = torch.zeros(input_tensor.shape, dtype=input_tensor.dtype)
# index and src slice
src_slice = torch.select(src_tensor, dim, i)
index_slice = torch.select(index, dim, i)
@@ -11,6 +11,7 @@
from .remove_detach import remove_detach
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .repair_input_as_output import repair_input_as_output
from .replace_full_like_with_full import replace_full_like_with_full
from .replace_max_pool_with_indices import replace_max_pool_with_indices
from .view_to_reshape import view_to_reshape

@@ -23,6 +24,7 @@
lower_linear,
fuse_prims_broadcast,
replace_max_pool_with_indices,
replace_full_like_with_full,
view_to_reshape,
]
)
@@ -0,0 +1,62 @@
import logging

import torch
import torch.fx
from torch_tensorrt.dynamo._defaults import default_device
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)
from torch_tensorrt.dynamo.utils import to_torch_device

logger = logging.getLogger(__name__)


def replace_full_like_with_full(
    gm: torch.fx.GraphModule,
) -> torch.fx.GraphModule:
    """Replace full_like nodes with equivalent full nodes"""
    modified_graph = False

    for node in gm.graph.nodes:
        if node.target == torch.ops.aten.full_like.default:
            modified_graph = True

            # Extract arguments from full_like
            input_tensor = node.args[0]
            fill_value = node.args[1]
            input_dtype = None
            input_shape = None
            input_device = to_torch_device(default_device())
            if "val" in input_tensor.meta:
                input_dtype = input_tensor.meta["val"].dtype
                input_device = input_tensor.meta["val"].device
                input_shape = list(input_tensor.meta["val"].shape)
            elif "tensor_meta" in input_tensor.meta:
                input_dtype = input_tensor.meta["tensor_meta"].dtype
                input_shape = list(input_tensor.meta["tensor_meta"].shape)

            # There's no memory format argument for torch.full.
            # Set the input_device and dtype correspondingly.
            new_kwargs = {}
            for key, val in node.kwargs.items():
                if key != "memory_format":
                    new_kwargs[key] = val
            new_kwargs["device"] = input_device
            new_kwargs["dtype"] = input_dtype
            # Replace full_like with full, using the shape as a list
            input_nodes = (input_shape, fill_value)
            with gm.graph.inserting_after(node):
                full_node = gm.graph.call_function(
                    torch.ops.aten.full.default,
                    args=input_nodes,
                    kwargs=new_kwargs,
                )
                full_node.meta = node.meta

            node.replace_all_uses_with(full_node)
            gm.graph.erase_node(node)

    if modified_graph:
        gm = clean_up_graph_after_modifications(gm)

    return gm
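A rough usage sketch of the new pass in isolation, with two assumptions flagged: the module path `torch_tensorrt.dynamo.lowering.passes.replace_full_like_with_full` is inferred from the relative import in the registration hunk above, and `torch.export` is assumed to keep `aten.full_like.default` un-decomposed so the pass has a node to rewrite:

```python
import torch
from torch_tensorrt.dynamo.lowering.passes.replace_full_like_with_full import (
    replace_full_like_with_full,  # import path inferred, see note above
)

class FillTwo(torch.nn.Module):
    def forward(self, x):
        return torch.full_like(x, 2.0)

# torch.export yields an ATen-level graph whose nodes carry meta["val"],
# which is where the pass reads the input's shape, dtype, and device.
ep = torch.export.export(FillTwo(), (torch.randn(3, 3),))
gm = ep.module()

gm = replace_full_like_with_full(gm)
print(gm.graph)  # the full_like node should now be an aten.full.default call
```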
68 changes: 61 additions & 7 deletions tests/py/dynamo/lowering/test_decompositions.py
@@ -421,6 +421,66 @@ def forward(self, x):
f"MaxPool3d TRT outputs don't match with the original model.",
)

    def test_lowering_full_like_module(self):
        class FullLike(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x):
                y = torch.full_like(x, 2.0)
                return y

        # Operations expected to be removed in the traced graph after decompositions
        expected_ops = {torch.ops.aten.full.default}
        unexpected_ops = {torch.ops.aten.full_like.default}

        inputs = [torch.randn(3, 3, dtype=torch.float32).cuda()]

        fx_graph = torch.fx.symbolic_trace(FullLike())
        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
            fx_graph,
            inputs,
            expected_ops=expected_ops,
            unexpected_ops=unexpected_ops,
            min_block_size=1,
        )

        self.assertEqual(
            len(unexpected_ops_seen),
            0,
            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )

        self.assertEqual(
            len(expected_ops_unseen),
            0,
            f"The following expected ops were not encountered: {expected_ops_unseen}",
        )

        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = torch_tensorrt.compile(
            fx_graph,
            "torch_compile",
            inputs,
            min_block_size=1,
            truncate_double=True,
            pass_through_build_failures=True,
        )
        optimized_model_results = optimized_model(*inputs).detach().cpu()
        torch_model_results = fx_graph(*inputs).detach().cpu()

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            f"FullLike TRT outputs don't match with the original model.",
        )

def test_lowering_empty_like_module(self):
class emptyLike(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
@@ -976,7 +1036,7 @@ def forward(self, input):
0,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(),
{torch.ops.aten.add.Tensor, torch.ops.aten.scatter.src},
{torch.ops.aten.add.Tensor},
),
(
"scatter_add_one_dim_indexOne_constant",
@@ -985,8 +1045,6 @@ def forward(self, input):
torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(),
{
torch.ops.aten.add.Tensor,
torch.ops.aten.scatter.src,
torch.ops.aten.full_like.default,
},
),
(
@@ -996,8 +1054,6 @@ def forward(self, input):
torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(),
{
torch.ops.aten.add.Tensor,
torch.ops.aten.scatter.src,
torch.ops.aten.full_like.default,
},
),
(
@@ -1009,8 +1065,6 @@ def forward(self, input):
).cuda(),
{
torch.ops.aten.add.Tensor,
torch.ops.aten.scatter.src,
torch.ops.aten.full_like.default,
},
),
]
