Skip to content

Commit

Permalink
chore: promoted type in cat ops for only multiple input
Browse files Browse the repository at this point in the history
  • Loading branch information
keehyuna committed Sep 4, 2024
1 parent 99f76f7 commit c7211e1
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions py/torch_tensorrt/dynamo/conversion/impl/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,25 @@ def cat(
)
trt_inputs.append(each_input)

# Cast to promoted type for all inputs
promoted_type = trt_inputs[0].dtype
for each_input in trt_inputs[1:]:
promoted_type = _enums.dtype._from(
torch.promote_types(
_enums.dtype._from(promoted_type).to(torch.dtype),
_enums.dtype._from(each_input.dtype).to(torch.dtype),
if len(trt_inputs) > 1:
# Cast to promoted type for all inputs
promoted_type = trt_inputs[0].dtype
for each_input in trt_inputs[1:]:
promoted_type = _enums.dtype._from(
torch.promote_types(
_enums.dtype._from(promoted_type).to(torch.dtype),
_enums.dtype._from(each_input.dtype).to(torch.dtype),
)
)
)
trt_promoted_type = promoted_type.to(trt.DataType)

trt_promoted_type = promoted_type.to(trt.DataType)
trt_casted_inputs = []
for i, each_input in enumerate(trt_inputs):
casted_input = cast_trt_tensor(
ctx, each_input, trt_promoted_type, f"{name}_input_casted_{i}"
)
trt_casted_inputs.append(casted_input)
trt_inputs = trt_casted_inputs
trt_casted_inputs = []
for i, each_input in enumerate(trt_inputs):
casted_input = cast_trt_tensor(
ctx, each_input, trt_promoted_type, f"{name}_input_casted_{i}"
)
trt_casted_inputs.append(casted_input)
trt_inputs = trt_casted_inputs

concat_layer = ctx.net.add_concatenation(trt_inputs)
dim = get_positive_dim(dim, len(trt_inputs[0].shape))
Expand Down

0 comments on commit c7211e1

Please sign in to comment.