1- from typing import Optional , Sequence , Union
1+ from typing import List , Optional , Sequence , Union
22
33import numpy as np
44import tensorrt as trt
1111from torch_tensorrt .dynamo .conversion .converter_utils import (
1212 cast_trt_tensor ,
1313 get_positive_dim ,
14- get_trt_tensor ,
1514 set_layer_name ,
1615)
1716
1817
18+ def unify_and_concat_trt_tensors (
19+ ctx : ConversionContext ,
20+ target : Target ,
21+ name : str ,
22+ inputs : Sequence [Union [int , np .ndarray , torch .Tensor , TRTTensor ]],
23+ concat_axis : int ,
24+ cast_dtype : Union [_enums .dtype , trt .DataType , np .dtype ] = None ,
25+ force_trt_output : bool = False ,
26+ ) -> Union [TRTTensor , List [int ]]:
27+ """
28+ Normalize all inputs to TRT tensors if needed, optionally cast, and concat if any dynamic.
29+
30+ Args:
31+ ctx: TensorRT conversion context.
32+ target: Operation Target.
33+ name: Operation Name.
34+ inputs: Sequence of ints / numpy arrays / torch tensors / TRT tensors.
35+ concat_axis: Axis along which to concatenate tensors if dynamic.
36+ cast_dtype: Optional target dtype for casting TRT tensors.
37+ force_trt_output: If True, return TRT tensor even if all inputs are static ints. (True for concat operations)
38+ """
39+ has_dynamic = any (not isinstance (x , int ) for x in inputs )
40+ trt_tensors = []
41+
42+ for i , x in enumerate (inputs ):
43+ # convert to TRTTensor
44+ if isinstance (x , TRTTensor ):
45+ t = x
46+ elif isinstance (x , int ) and not has_dynamic and not force_trt_output :
47+ t = x # pure static path
48+ else :
49+ const_arr = np .array ([x ], dtype = np .int32 )
50+ shape = (1 ,)
51+ if not isinstance (x , int ):
52+ const_arr = np .array (x , dtype = np .int32 )
53+ shape = (x .numel (),)
54+
55+ layer = ctx .net .add_constant (shape , const_arr )
56+ set_layer_name (layer , target , f"{ name } _dim{ i } _const" )
57+ t = layer .get_output (0 )
58+ trt_tensors .append (t )
59+
60+ if not has_dynamic and not force_trt_output :
61+ return trt_tensors # all ints
62+
63+ final_dtype = None
64+ if cast_dtype :
65+ # Explicit cast requested
66+ if isinstance (cast_dtype , _enums .dtype ):
67+ final_dtype = cast_dtype .to (trt .DataType )
68+ elif isinstance (cast_dtype , (np .dtype , torch .dtype )):
69+ final_dtype = _enums .dtype ._from (cast_dtype ).to (trt .DataType )
70+ else :
71+ final_dtype = cast_dtype # already trt.DataType
72+ else :
73+ # Automatic promotion
74+ promoted_type = None
75+ for t in trt_tensors :
76+ if isinstance (t , TRTTensor ):
77+ if promoted_type is None :
78+ promoted_type = t .dtype
79+ else :
80+ promoted_type = _enums .dtype ._from (
81+ torch .promote_types (
82+ _enums .dtype ._from (promoted_type ).to (torch .dtype ),
83+ _enums .dtype ._from (t .dtype ).to (torch .dtype ),
84+ )
85+ ).to (trt .DataType )
86+ final_dtype = promoted_type
87+
88+ # promote remaining ints to TRT consts before concat
89+ for i , t in enumerate (trt_tensors ):
90+ if isinstance (t , int ):
91+ const = ctx .net .add_constant ((1 ,), np .array ([t ], dtype = np .int32 ))
92+ set_layer_name (const , target , f"{ name } _static_{ i } _const" )
93+ trt_tensors [i ] = const .get_output (0 )
94+
95+ # final cast
96+ if final_dtype is not None :
97+ casted = []
98+ for i , t in enumerate (trt_tensors ):
99+ if isinstance (t , TRTTensor ):
100+ t = cast_trt_tensor (ctx , t , final_dtype , f"{ name } _cast_{ i } " )
101+ casted .append (t )
102+ trt_tensors = casted
103+
104+ concat = ctx .net .add_concatenation (trt_tensors )
105+ concat .axis = concat_axis
106+ set_layer_name (concat , target , f"{ name } _concat" )
107+ return concat .get_output (0 )
108+
109+
19110def cat (
20111 ctx : ConversionContext ,
21112 target : Target ,
@@ -25,38 +116,17 @@ def cat(
25116 dim : int ,
26117 cast_dtype : Union [_enums .dtype , trt .DataType , np .dtype ] = None ,
27118) -> Union [TRTTensor , Sequence [TRTTensor ]]:
28- trt_inputs = []
29- for i , each_input in enumerate (input ):
30- if not isinstance (each_input , TRTTensor ):
31- each_input = get_trt_tensor (ctx , each_input , f"{ name } _tensor_{ i } " )
32- if cast_dtype :
33- each_input = cast_trt_tensor (
34- ctx , each_input , cast_dtype , f"{ name } _tensor_int32_cast_{ i } "
35- )
36- trt_inputs .append (each_input )
37-
38- if len (trt_inputs ) > 1 :
39- # Cast to promoted type for all inputs
40- promoted_type = trt_inputs [0 ].dtype
41- for each_input in trt_inputs [1 :]:
42- promoted_type = _enums .dtype ._from (
43- torch .promote_types (
44- _enums .dtype ._from (promoted_type ).to (torch .dtype ),
45- _enums .dtype ._from (each_input .dtype ).to (torch .dtype ),
46- )
47- )
48- trt_promoted_type = promoted_type .to (trt .DataType )
49-
50- trt_casted_inputs = []
51- for i , each_input in enumerate (trt_inputs ):
52- casted_input = cast_trt_tensor (
53- ctx , each_input , trt_promoted_type , f"{ name } _input_casted_{ i } "
54- )
55- trt_casted_inputs .append (casted_input )
56- trt_inputs = trt_casted_inputs
57-
58- concat_layer = ctx .net .add_concatenation (trt_inputs )
59- dim = get_positive_dim (dim , len (trt_inputs [0 ].shape ))
60- concat_layer .axis = dim
61- set_layer_name (concat_layer , target , f"{ name } _gather" , source_ir )
62- return concat_layer .get_output (0 )
119+ # int is only when cat called in other ops like pad
120+ if not isinstance (input [0 ], int ):
121+ dim = get_positive_dim (dim , len (input [0 ].shape ))
122+ else :
123+ dim = 0
124+ return unify_and_concat_trt_tensors (
125+ ctx ,
126+ target ,
127+ name ,
128+ input ,
129+ concat_axis = dim ,
130+ cast_dtype = cast_dtype ,
131+ force_trt_output = True ,
132+ )
0 commit comments