
Commit 4476792

chore: bug fixes for full and expand (#3019)
1 parent c99c966 commit 4476792

File tree: 7 files changed (+120, -33 lines)


core/runtime/execute_engine.cpp

+7-6
@@ -114,6 +114,9 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
   // Whether cudagraphs needs to record the graph on this pass
   bool need_cudagraphs_record = (CUDAGRAPHS_MODE && !_cudagraphs_validate_shapes(inputs, compiled_engine));
 
+  // this is a buffer to store shape tensor input addresses throughout the runtime scope
+  std::list<std::vector<int64_t>> inputShapeTensorValues;
+
   // Initialize outputs to be available throughout the succeeding scopes
   std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
 
@@ -177,8 +180,6 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     }
   }
 
-  // this is a buffer to store shape tensor input addresses throughout the runtime scope
-  std::list<std::vector<int32_t>> inputShapeTensorValues;
   {
     std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
     if (compiled_engine->profile_execution) {
@@ -200,12 +201,12 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
       at::Tensor contig_input;
 
       if (compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str())) {
-        // Shape tensor inputs are casted to int32 explicitly.
+        // Shape tensor inputs are casted to int64 explicitly.
         // Refer to
         // https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
-        auto input_cpu = inputs[i].clone().contiguous().cpu().to(torch::kInt32);
-        std::vector<int32_t> inputs_cpu_vec(
-            input_cpu.data_ptr<int32_t>(), input_cpu.data_ptr<int32_t>() + input_cpu.numel());
+        auto input_cpu = inputs[i].clone().contiguous().cpu().to(torch::kInt64);
+        std::vector<int64_t> inputs_cpu_vec(
+            input_cpu.data_ptr<int64_t>(), input_cpu.data_ptr<int64_t>() + input_cpu.numel());
         inputShapeTensorValues.emplace_back(inputs_cpu_vec);
         TORCHTRT_CHECK(
             compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()),
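
For reference, a minimal Python sketch of the same host-side staging: the shape tensor is copied to the CPU, cast to int64, and the resulting values are kept alive for the whole execution scope, mirroring the hoisted inputShapeTensorValues buffer. The names stage_shape_tensor and keepalive are illustrative only and are not part of the runtime.

import torch

def stage_shape_tensor(shape_input: torch.Tensor, keepalive: list) -> list:
    # Move the shape tensor to host memory and cast to int64, mirroring the
    # runtime change above (shape tensors are consumed from CPU addresses).
    input_cpu = shape_input.clone().contiguous().cpu().to(torch.int64)
    values = input_cpu.tolist()
    # Keep the staged values alive for the duration of execution, like the
    # hoisted inputShapeTensorValues list in execute_engine.cpp.
    keepalive.append(values)
    return values

keepalive: list = []
print(stage_shape_tensor(torch.tensor([2, 3, 4]), keepalive))  # [2, 3, 4]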

py/torch_tensorrt/dynamo/conversion/impl/full.py

+7-1
@@ -23,7 +23,13 @@ def full(
 ) -> TRTTensor:
     # in static shape scenario, shape is a list of int
     if isinstance(shape, List):
-        return np.full(shape, fill_value)
+        # in static shape scenario, shape is a list of int
+        if all(isinstance(dim, int) for dim in shape):
+            return np.full(shape, fill_value)
+        else:
+            shape = impl.cat.cat(
+                ctx, target, source_ir, name + "_concat_shape", shape, 0
+            )
 
     # in dynamic shape scenario, shape is a shape tensor
     # use IFillLayer to fill the shape tensor with LINSPACE value
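
The gist of the fix, sketched in plain Python/NumPy: a shape list made entirely of ints can be materialized directly, while a list that mixes ints with runtime (ITensor) dimensions must be turned into a shape tensor first. full_like_sketch is an illustrative name, a string stands in for an ITensor, and the dynamic branch only describes what the converter does with impl.cat.cat and IFillLayer.

import numpy as np

def full_like_sketch(shape, fill_value):
    # Static path: every dimension is a plain int, so the constant can be
    # materialized directly (mirrors the np.full branch above).
    if all(isinstance(dim, int) for dim in shape):
        return np.full(shape, fill_value)
    # Dynamic path: at least one dimension is only known at runtime; the real
    # converter concatenates the entries into a shape tensor and lets
    # IFillLayer produce the output. Here we just report the fallback.
    return f"dynamic shape {shape!r}: would concat into a shape tensor and use IFillLayer"

print(full_like_sketch([2, 3], 0.11))          # static: a 2x3 array of 0.11
print(full_like_sketch([2, "ITensor"], 0.11))  # dynamic: falls through to the shape-tensor path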

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

+35-22
@@ -226,6 +226,7 @@ def expand(
 ) -> TRTTensor:
     shape_rank = len(shape)
     initial_tensor_rank = len(input_t.shape)
+
     # If the rank of the input tensor is less than the shape's rank, pad with ones
     if initial_tensor_rank < shape_rank:
         input_t = prepend_ones(
@@ -244,39 +245,49 @@ def expand(
     # After the above padding, the shape and tensor rank must be equal
     assert len(input_t.shape) == shape_rank
 
-    shape_t = []
-    for i in range(shape_rank):
-        if shape[i] == -1:
-            shape_t.append(
-                get_shape(ctx, target, source_ir, name + f"_shape_dim{i}", input_t, i)
-            )
-        else:
-            shape_t.append(shape[i])
-
-    # Establish the desired output shape, strides, and starting indices
-    input_tensor_shape = tuple(input_t.shape)
+    # Configure the start, strides and output shape tensors
     start = tuple([0] * shape_rank)
 
-    # TODO: Revisit stride calculation. stride[dim]=0 implies that dimension is being broadcasted.
+    # stride[dim]=0 implies that dimension is being broadcasted.
     # stride should be 1 for all non-broadcasted dims
     stride = []
-    for i, o in zip(input_tensor_shape, shape_t):
-        # If the shape has ITensor, we treat it as a reshape dim instead of a broadcasted dim
-        # shape_t cannot have -1. If the input at this dimension has a shape of -1, set the stride to 1. This indicates that the input is dynamic and does not imply broadcasting at that specific dimension.
-        if isinstance(i, int) and isinstance(o, int) and i != DYNAMIC_DIM:
+    input_tensor_shape = tuple(input_t.shape)
+    for i, o in zip(input_tensor_shape, shape):
+        # If input dim and target shape dim are static, broadcast if they are not equal
+        # If input dim is known and target shape dim is dynamic we treat it as a broadcasted dim
+        if (
+            isinstance(i, int)
+            and i != DYNAMIC_DIM
+            and isinstance(o, int)
+            and o != DYNAMIC_DIM
+        ):
             stride.append(int(i == o))
+        elif isinstance(i, int) and i != DYNAMIC_DIM and isinstance(o, TRTTensor):
+            stride.append(0)
         else:
+            # No broadcasting is happening. The output should have the same size as input at this dimension.
             stride.append(1)
 
-    shape_ = shape_t
+    # Resolve dynamic dimensions in the target shape. These are not broadcasted dims.
+    # The value at this dimension should be same as input.
+    target_shape = []
+    for i in range(shape_rank):
+        if shape[i] == DYNAMIC_DIM:
+            target_shape.append(
+                get_shape(ctx, target, source_ir, name + f"_shape_dim{i}", input_t, i)
+            )
+        else:
+            target_shape.append(shape[i])
+
+    target_shape_t = target_shape
     # Handle dynamic shapes case where shape has dynamic dimension
-    if any(isinstance(ele, TRTTensor) for ele in shape_t):
-        shape_ = cat(
+    if any(isinstance(ele, TRTTensor) for ele in target_shape_t):
+        target_shape_t = cat(
             ctx,
             target,
             source_ir,
             name + "_shape_concat",
-            shape_t,
+            target_shape_t,
             0,
             cast_dtype=trt.int32,
         )
@@ -302,10 +313,12 @@ def expand(
             input_t, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
         )
         layer.set_input(1, start_tensor)
-        layer.set_input(2, shape_)
+        layer.set_input(2, target_shape_t)
         layer.set_input(3, stride_tensor)
     else:
-        layer = ctx.net.add_slice(input_t, start=start, shape=shape_, stride=stride)
+        layer = ctx.net.add_slice(
+            input_t, start=start, shape=target_shape_t, stride=stride
+        )
 
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
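
The heart of the expand change is the per-dimension stride rule (stride 0 broadcasts a dimension, stride 1 copies it) plus the resolution of -1 entries in the target shape. Below is a pure-Python sketch of that planning logic, assuming DYNAMIC_DIM == -1 and using the string "ITensor" as a stand-in for a TRTTensor-valued dimension; expand_plan is an illustrative name, not the converter itself.

# Sketch of the stride/target-shape logic above, on plain Python values.
DYNAMIC_DIM = -1

def expand_plan(input_shape, target):
    stride = []
    for i, o in zip(input_shape, target):
        if isinstance(i, int) and i != DYNAMIC_DIM and isinstance(o, int) and o != DYNAMIC_DIM:
            stride.append(int(i == o))  # 1 = copy (sizes match), 0 = broadcast this dim
        elif isinstance(i, int) and i != DYNAMIC_DIM and not isinstance(o, int):
            stride.append(0)            # known input dim, runtime target dim: broadcast
        else:
            stride.append(1)            # dynamic input dim: no broadcasting here
    # -1 in the target means "same as the input at this dimension"
    resolved = [i if o == DYNAMIC_DIM else o for i, o in zip(input_shape, target)]
    return resolved, stride

print(expand_plan((1, 1, 768), (4, -1, 768)))         # ([4, 1, 768], [0, 1, 1])
print(expand_plan((1, 1, 768), (4, "ITensor", 768)))  # ([4, 'ITensor', 768], [0, 0, 1])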

py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py

+8
@@ -1,3 +1,4 @@
+import copy
 import logging
 import operator
 from typing import Callable, Sequence, Tuple
@@ -54,6 +55,13 @@ def lower_scaled_dot_product_attention(
             == torch.nn.functional.scaled_dot_product_attention
         )
 
+        # Copy the metadata of the replaced attention node to the new node
+        # TODO: Investigate why there are multiple FakeTensors in the metadata.
+        # We only use the first one as it contains the output shape information for this node.
+        new_attention_node.meta["val"] = copy.copy(
+            attention_node_replaced.meta["val"][0]
+        )
+
         # If the attention operator had keyword-args, copy them to the new node
         if attention_node_replaced.kwargs:
             new_attention_node.kwargs = {**attention_node_replaced.kwargs}
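
A minimal torch.fx sketch of that metadata copy, assuming the usual node.meta["val"] convention for fake-tensor shape information. Toy and the reuse of the relu node as the "new" node are purely illustrative; the real pass inserts a scaled_dot_product_attention call and copies onto that.

import copy
import torch
from torch.fx import symbolic_trace

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

gm = symbolic_trace(Toy())
old_node = next(n for n in gm.graph.nodes if n.op == "call_function")
new_node = old_node  # stand-in for the freshly inserted replacement node
val = old_node.meta.get("val")
if val is not None:
    # meta["val"] may be a FakeTensor or a sequence of them; the pass keeps
    # only the first entry, which carries the output shape for the node.
    new_node.meta["val"] = copy.copy(val[0] if isinstance(val, (list, tuple)) else val)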

tests/py/dynamo/conversion/harness.py

+8-1
@@ -399,7 +399,14 @@ def run_test_with_dynamic_shape(
         )
         # Since the lowering is based on the optimal shape, we need to test with
         # a different shape (e.g. max shape) to exercise dynamic shapes
-        inputs_max = [spec.example_tensor("max_shape") for spec in input_specs]
+        inputs_max = [
+            (
+                spec.example_tensor("max_shape")
+                if spec.shape_mode == Input._ShapeMode.DYNAMIC
+                else spec.example_tensor()
+            )
+            for spec in input_specs
+        ]
         if not use_example_tensors:
             inputs_max = [spec.torch_tensor for spec in input_specs]
         super().run_test(mod, inputs_max, interp, rtol, atol, pyt_inputs=pyt_inputs)
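
A small sketch of the selection rule: only dynamic specs are asked for their max_shape example, while static specs fall back to their single shape. Plain dicts stand in for torch_tensorrt.Input specs here; the real check is spec.shape_mode == Input._ShapeMode.DYNAMIC.

# Illustrative stand-in for the harness logic above.
specs = [
    {"mode": "dynamic", "min": (1, 5, 3), "opt": (3, 7, 3), "max": (4, 10, 4)},
    {"mode": "static", "shape": (2, 3)},
]

inputs_max = [
    spec["max"] if spec["mode"] == "dynamic" else spec["shape"]
    for spec in specs
]
print(inputs_max)  # [(4, 10, 4), (2, 3)]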

tests/py/dynamo/conversion/test_expand_aten.py

+29-3
@@ -37,8 +37,10 @@ def forward(self, x):
             ("different_ranks", (1, 2, 1), (1, 2, 1), (2, 2, 1), (2, -1, -1, -1)),
         ]
     )
-    def test_expand_dynamic(self, _, min_shape, opt_shape, max_shape, expanded_shape):
-        class ExpandDynamic(nn.Module):
+    def test_expand_dynamic_input(
+        self, _, min_shape, opt_shape, max_shape, expanded_shape
+    ):
+        class ExpandInputDynamic(nn.Module):
             def forward(self, x):
                 return torch.ops.aten.expand.default(x, expanded_shape)
 
@@ -51,10 +53,34 @@ def forward(self, x):
             ),
         ]
         self.run_test_with_dynamic_shape(
-            ExpandDynamic(),
+            ExpandInputDynamic(),
            input_specs,
         )
 
+    @parameterized.expand(
+        [
+            ("3d_dim", (4, 1, 768), (1, 1, 768)),
+        ]
+    )
+    def test_expand_dynamic_target_shape(self, _, input_shape, weight_shape):
+        class ExpandTargetDynamic(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+                self.cls_token = torch.nn.Parameter(torch.randn(weight_shape).cuda())
+
+            def forward(self, x):
+                batch_size = x.shape[0]
+                cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+                embeddings = torch.cat((cls_tokens, x), dim=0)
+                return embeddings
+
+        input_specs = [
+            Input(dtype=torch.float32, shape=input_shape),
+        ]
+        self.run_test_with_dynamic_shape(
+            ExpandTargetDynamic(), input_specs, use_dynamo_tracer=True
+        )
+
 
 if __name__ == "__main__":
     run_tests()
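
In eager PyTorch, the pattern exercised by ExpandTargetDynamic looks roughly like the sketch below: a ViT-style cls-token expansion where the expand target depends on the runtime batch size. CPU tensors are used here instead of .cuda() so the sketch runs anywhere.

import torch

cls_token = torch.nn.Parameter(torch.randn(1, 1, 768))  # the test keeps this on CUDA
x = torch.randn(4, 1, 768)

batch_size = x.shape[0]
cls_tokens = cls_token.expand(batch_size, -1, -1)  # -1 keeps the original size of that dim
embeddings = torch.cat((cls_tokens, x), dim=0)
print(embeddings.shape)  # torch.Size([8, 1, 768])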

tests/py/dynamo/conversion/test_full_aten.py

+26
@@ -55,6 +55,32 @@ def forward(self, shape):
             use_example_tensors=False,
         )
 
+    @parameterized.expand(
+        [
+            ((1, 5, 3), (3, 7, 3), (4, 10, 4), 0.11),
+        ]
+    )
+    def test_full_dynamic_shape_list(self, min_shape, opt_shape, max_shape, fill_value):
+        class full(nn.Module):
+            def forward(self, x):
+                shape = x.shape[0]
+                target_shape = (shape, shape + 1)
+                return torch.ops.aten.full.default(target_shape, fill_value)
+
+        inputs = [
+            torch_tensorrt.Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=torch.int64,
+            )
+        ]
+        self.run_test_with_dynamic_shape(
+            full(),
+            inputs,
+            use_dynamo_tracer=True,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
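
In eager mode, this new test corresponds to building a full() target shape from a runtime dimension of the input, which is exactly the case the converter's shape-list fix covers. A small sketch (full_from_input is an illustrative name):

import torch

def full_from_input(x, fill_value=0.11):
    # The target shape is derived from a runtime dimension of the input,
    # which is what makes the shape "dynamic" when compiled with dynamic shapes.
    n = x.shape[0]
    return torch.ops.aten.full.default((n, n + 1), fill_value)

print(full_from_input(torch.zeros(3, 7, 3, dtype=torch.int64)).shape)  # torch.Size([3, 4])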
