diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cat.py b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
index 0a52d1efd5..096bc1aa24 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/cat.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
@@ -35,24 +35,25 @@ def cat(
             )
         trt_inputs.append(each_input)
 
-    # Cast to promoted type for all inputs
-    promoted_type = trt_inputs[0].dtype
-    for each_input in trt_inputs[1:]:
-        promoted_type = _enums.dtype._from(
-            torch.promote_types(
-                _enums.dtype._from(promoted_type).to(torch.dtype),
-                _enums.dtype._from(each_input.dtype).to(torch.dtype),
+    if len(trt_inputs) > 1:
+        # Cast to promoted type for all inputs
+        promoted_type = trt_inputs[0].dtype
+        for each_input in trt_inputs[1:]:
+            promoted_type = _enums.dtype._from(
+                torch.promote_types(
+                    _enums.dtype._from(promoted_type).to(torch.dtype),
+                    _enums.dtype._from(each_input.dtype).to(torch.dtype),
+                )
             )
-        )
+        trt_promoted_type = promoted_type.to(trt.DataType)
 
-    trt_promoted_type = promoted_type.to(trt.DataType)
-    trt_casted_inputs = []
-    for i, each_input in enumerate(trt_inputs):
-        casted_input = cast_trt_tensor(
-            ctx, each_input, trt_promoted_type, f"{name}_input_casted_{i}"
-        )
-        trt_casted_inputs.append(casted_input)
-    trt_inputs = trt_casted_inputs
+        trt_casted_inputs = []
+        for i, each_input in enumerate(trt_inputs):
+            casted_input = cast_trt_tensor(
+                ctx, each_input, trt_promoted_type, f"{name}_input_casted_{i}"
+            )
+            trt_casted_inputs.append(casted_input)
+        trt_inputs = trt_casted_inputs
 
     concat_layer = ctx.net.add_concatenation(trt_inputs)
     dim = get_positive_dim(dim, len(trt_inputs[0].shape))