@@ -3041,9 +3041,175 @@ def aten_ops_pad(
3041
3041
)
3042
3042
3043
3043
3044
- @dynamo_tensorrt_converter (torch .ops .aten .upsample_nearest2d .default )
3045
- @dynamo_tensorrt_converter (torch .ops .aten .upsample_nearest2d .vec )
3046
- def upsample_nearest2d (
3044
+ for op in (
3045
+ torch .ops .aten .upsample_nearest1d ,
3046
+ torch .ops .aten .upsample_nearest2d ,
3047
+ torch .ops .aten .upsample_nearest3d ,
3048
+ torch .ops .aten .upsample_linear1d ,
3049
+ torch .ops .aten .upsample_bilinear2d ,
3050
+ torch .ops .aten .upsample_trilinear3d ,
3051
+ torch .ops .aten .upsample_bicubic2d ,
3052
+ ):
3053
+ for key in (
3054
+ torch ._C .DispatchKey .Autograd ,
3055
+ torch ._C .DispatchKey .CompositeImplicitAutograd ,
3056
+ ):
3057
+ if key in op .default .py_kernels :
3058
+ del op .default .py_kernels [key ]
3059
+ if key in op .vec .py_kernels :
3060
+ del op .vec .py_kernels [key ]
3061
+
3062
+
3063
+ def upsample_compute_output_size (
3064
+ input_size : torch .Size ,
3065
+ output_size : Optional [Sequence [int ]],
3066
+ scale_factors : Optional [Sequence [float ]],
3067
+ ) -> Sequence [int ]:
3068
+ spatial_dimensions = len (input_size ) - 2
3069
+
3070
+ if output_size is not None :
3071
+ torch ._check (
3072
+ scale_factors is None ,
3073
+ lambda : "Must specify exactly one of output_size and scale_factors" ,
3074
+ )
3075
+ torch ._check (len (output_size ) == spatial_dimensions )
3076
+ return output_size
3077
+
3078
+ if scale_factors is not None :
3079
+ torch ._check (
3080
+ output_size is None ,
3081
+ lambda : "Must specify exactly one of output_size and scale_factors" ,
3082
+ )
3083
+ torch ._check (len (scale_factors ) == spatial_dimensions )
3084
+ output_size = []
3085
+ for i , s in enumerate (scale_factors ):
3086
+ output_size .append (int (input_size [i + 2 ] * s ))
3087
+ return output_size
3088
+
3089
+ torch ._check (
3090
+ False , lambda : "Must specify exactly one of output_size and scale_factors"
3091
+ )
3092
+
3093
+
3094
+ @torch .ops .aten .upsample_nearest1d .vec .py_impl (
3095
+ torch ._C .DispatchKey .CompositeImplicitAutograd
3096
+ )
3097
+ def upsample_nearest1d_vec (
3098
+ input : torch .Tensor ,
3099
+ output_size : Optional [Sequence [int ]],
3100
+ scale_factors : Optional [Sequence [float ]],
3101
+ ) -> torch .Tensor :
3102
+ osize = upsample_compute_output_size (input .size (), output_size , scale_factors )
3103
+ if scale_factors is not None :
3104
+ return torch .ops .aten .upsample_nearest1d .default (input , osize , * scale_factors )
3105
+ return torch .ops .aten .upsample_nearest1d .default (input , osize )
3106
+
3107
+
3108
+ @torch .ops .aten .upsample_nearest2d .vec .py_impl (
3109
+ torch ._C .DispatchKey .CompositeImplicitAutograd
3110
+ )
3111
+ def upsample_nearest2d_vec (
3112
+ input : torch .Tensor ,
3113
+ output_size : Optional [Sequence [int ]],
3114
+ scale_factors : Optional [Sequence [float ]],
3115
+ ) -> torch .Tensor :
3116
+ osize = upsample_compute_output_size (input .size (), output_size , scale_factors )
3117
+ if scale_factors is not None :
3118
+ return torch .ops .aten .upsample_nearest2d .default (input , osize , * scale_factors )
3119
+ return torch .ops .aten .upsample_nearest2d .default (input , osize )
3120
+
3121
+
3122
+ @torch .ops .aten .upsample_nearest3d .vec .py_impl (
3123
+ torch ._C .DispatchKey .CompositeImplicitAutograd
3124
+ )
3125
+ def upsample_nearest3d_vec (
3126
+ input : torch .Tensor ,
3127
+ output_size : Optional [Sequence [int ]],
3128
+ scale_factors : Optional [Sequence [float ]],
3129
+ ) -> torch .Tensor :
3130
+ osize = upsample_compute_output_size (input .size (), output_size , scale_factors )
3131
+ if scale_factors is not None :
3132
+ return torch .ops .aten .upsample_nearest3d .default (input , osize , * scale_factors )
3133
+ return torch .ops .aten .upsample_nearest3d .default (input , osize )
3134
+
3135
+
3136
+ @torch .ops .aten .upsample_linear1d .vec .py_impl (
3137
+ torch ._C .DispatchKey .CompositeImplicitAutograd
3138
+ )
3139
+ def upsample_linear1d_vec (
3140
+ input : torch .Tensor ,
3141
+ output_size : Optional [Sequence [int ]],
3142
+ align_corners : bool ,
3143
+ scale_factors : Optional [Sequence [float ]],
3144
+ ) -> torch .Tensor :
3145
+ osize = upsample_compute_output_size (input .size (), output_size , scale_factors )
3146
+ if scale_factors is not None :
3147
+ return torch .ops .aten .upsample_linear1d .default (
3148
+ input , osize , align_corners , * scale_factors
3149
+ )
3150
+ return torch .ops .aten .upsample_linear1d .default (input , osize , align_corners )
3151
+
3152
+
3153
+ @torch .ops .aten .upsample_bilinear2d .vec .py_impl (
3154
+ torch ._C .DispatchKey .CompositeImplicitAutograd
3155
+ )
3156
+ def upsample_bilinear2d_vec (
3157
+ input : torch .Tensor ,
3158
+ output_size : Optional [Sequence [int ]],
3159
+ align_corners : bool ,
3160
+ scale_factors : Optional [Sequence [float ]],
3161
+ ) -> torch .Tensor :
3162
+ osize = upsample_compute_output_size (input .size (), output_size , scale_factors )
3163
+ if scale_factors is not None :
3164
+ return torch .ops .aten .upsample_bilinear2d .default (
3165
+ input , osize , align_corners , * scale_factors
3166
+ )
3167
+ return torch .ops .aten .upsample_bilinear2d .default (input , osize , align_corners )
3168
+
3169
+
3170
+ @torch .ops .aten .upsample_trilinear3d .vec .py_impl (
3171
+ torch ._C .DispatchKey .CompositeImplicitAutograd
3172
+ )
3173
+ def upsample_trilinear3d_vec (
3174
+ input : torch .Tensor ,
3175
+ output_size : Optional [Sequence [int ]],
3176
+ align_corners : bool ,
3177
+ scale_factors : Optional [Sequence [float ]],
3178
+ ) -> torch .Tensor :
3179
+ osize = upsample_compute_output_size (input .size (), output_size , scale_factors )
3180
+ if scale_factors is not None :
3181
+ return torch .ops .aten .upsample_trilinear3d .default (
3182
+ input , osize , align_corners , * scale_factors
3183
+ )
3184
+ return torch .ops .aten .upsample_trilinear3d .default (input , osize , align_corners )
3185
+
3186
+
3187
+ @torch .ops .aten .upsample_bicubic2d .vec .py_impl (
3188
+ torch ._C .DispatchKey .CompositeImplicitAutograd
3189
+ )
3190
+ def upsample_bicubic2d_vec (
3191
+ input : torch .Tensor ,
3192
+ output_size : Optional [Sequence [int ]],
3193
+ align_corners : bool ,
3194
+ scale_factors : Optional [Sequence [float ]],
3195
+ ) -> torch .Tensor :
3196
+ osize = upsample_compute_output_size (input .size (), output_size , scale_factors )
3197
+ if scale_factors is not None :
3198
+ return torch .ops .aten .upsample_bicubic2d .default (
3199
+ input , osize , align_corners , * scale_factors
3200
+ )
3201
+ return torch .ops .aten .upsample_bicubic2d .default (input , osize , align_corners )
3202
+
3203
+
3204
+ @dynamo_tensorrt_converter (
3205
+ torch .ops .aten .upsample_nearest1d .default , supports_dynamic_shapes = True
3206
+ )
3207
+ @enforce_tensor_types (
3208
+ {
3209
+ 0 : (TRTTensor ,),
3210
+ }
3211
+ )
3212
+ def aten_ops_upsample_nearest1d (
3047
3213
ctx : ConversionContext ,
3048
3214
target : Target ,
3049
3215
args : Tuple [Argument , ...],
@@ -3055,17 +3221,23 @@ def upsample_nearest2d(
3055
3221
target ,
3056
3222
SourceIR .ATEN ,
3057
3223
name ,
3058
- input = args [0 ],
3059
- out_shape = args_bounds_check ( args , 1 ) ,
3060
- scale_factors = args_bounds_check (args , 2 ) ,
3061
- resize_mode = "nearest" ,
3224
+ args [0 ],
3225
+ size = args [ 1 ] ,
3226
+ scale_factor = None if len (args ) < 3 else [ args [ 2 ]] ,
3227
+ mode = "nearest" ,
3062
3228
align_corners = False ,
3063
3229
)
3064
3230
3065
3231
3066
- @dynamo_tensorrt_converter (torch .ops .aten .upsample_bilinear2d .default )
3067
- @dynamo_tensorrt_converter (torch .ops .aten .upsample_bilinear2d .vec )
3068
- def upsample_bilinear2d (
3232
+ @dynamo_tensorrt_converter (
3233
+ torch .ops .aten .upsample_nearest2d .default , supports_dynamic_shapes = True
3234
+ )
3235
+ @enforce_tensor_types (
3236
+ {
3237
+ 0 : (TRTTensor ,),
3238
+ }
3239
+ )
3240
+ def aten_ops_upsample_nearest2d (
3069
3241
ctx : ConversionContext ,
3070
3242
target : Target ,
3071
3243
args : Tuple [Argument , ...],
@@ -3077,11 +3249,151 @@ def upsample_bilinear2d(
3077
3249
target ,
3078
3250
SourceIR .ATEN ,
3079
3251
name ,
3080
- input = args [0 ],
3081
- out_shape = args_bounds_check (args , 1 ),
3082
- scale_factors = args_bounds_check (args , 3 ),
3083
- resize_mode = "bilinear" ,
3084
- align_corners = args_bounds_check (args , 2 ),
3252
+ args [0 ],
3253
+ size = args [1 ],
3254
+ scale_factor = None if len (args ) < 4 else [args [2 ], args [3 ]],
3255
+ mode = "nearest" ,
3256
+ align_corners = False ,
3257
+ )
3258
+
3259
+
3260
+ @dynamo_tensorrt_converter (
3261
+ torch .ops .aten .upsample_nearest3d .default , supports_dynamic_shapes = True
3262
+ )
3263
+ @enforce_tensor_types (
3264
+ {
3265
+ 0 : (TRTTensor ,),
3266
+ }
3267
+ )
3268
+ def aten_ops_upsample_nearest3d (
3269
+ ctx : ConversionContext ,
3270
+ target : Target ,
3271
+ args : Tuple [Argument , ...],
3272
+ kwargs : Dict [str , Argument ],
3273
+ name : str ,
3274
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
3275
+ return impl .upsample .upsample (
3276
+ ctx ,
3277
+ target ,
3278
+ SourceIR .ATEN ,
3279
+ name ,
3280
+ args [0 ],
3281
+ size = args [1 ],
3282
+ scale_factor = None if len (args ) < 5 else [args [2 ], args [3 ], args [4 ]],
3283
+ mode = "nearest" ,
3284
+ align_corners = False ,
3285
+ )
3286
+
3287
+
3288
+ @dynamo_tensorrt_converter (
3289
+ torch .ops .aten .upsample_linear1d .default , supports_dynamic_shapes = True
3290
+ )
3291
+ @enforce_tensor_types (
3292
+ {
3293
+ 0 : (TRTTensor ,),
3294
+ }
3295
+ )
3296
+ def aten_ops_upsample_linear1d (
3297
+ ctx : ConversionContext ,
3298
+ target : Target ,
3299
+ args : Tuple [Argument , ...],
3300
+ kwargs : Dict [str , Argument ],
3301
+ name : str ,
3302
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
3303
+ return impl .upsample .upsample (
3304
+ ctx ,
3305
+ target ,
3306
+ SourceIR .ATEN ,
3307
+ name ,
3308
+ args [0 ],
3309
+ size = args [1 ],
3310
+ scale_factor = None if len (args ) < 4 else [args [3 ]],
3311
+ mode = "linear" ,
3312
+ align_corners = args [2 ],
3313
+ )
3314
+
3315
+
3316
+ @dynamo_tensorrt_converter (
3317
+ torch .ops .aten .upsample_bilinear2d .default , supports_dynamic_shapes = True
3318
+ )
3319
+ @enforce_tensor_types (
3320
+ {
3321
+ 0 : (TRTTensor ,),
3322
+ }
3323
+ )
3324
+ def aten_ops_upsample_bilinear2d (
3325
+ ctx : ConversionContext ,
3326
+ target : Target ,
3327
+ args : Tuple [Argument , ...],
3328
+ kwargs : Dict [str , Argument ],
3329
+ name : str ,
3330
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
3331
+ return impl .upsample .upsample (
3332
+ ctx ,
3333
+ target ,
3334
+ SourceIR .ATEN ,
3335
+ name ,
3336
+ args [0 ],
3337
+ size = args [1 ],
3338
+ scale_factor = None if len (args ) < 5 else [args [3 ], args [4 ]],
3339
+ mode = "bilinear" ,
3340
+ align_corners = args [2 ],
3341
+ )
3342
+
3343
+
3344
+ @dynamo_tensorrt_converter (
3345
+ torch .ops .aten .upsample_trilinear3d .default , supports_dynamic_shapes = True
3346
+ )
3347
+ @enforce_tensor_types (
3348
+ {
3349
+ 0 : (TRTTensor ,),
3350
+ }
3351
+ )
3352
+ def aten_ops_upsample_trilinear3d (
3353
+ ctx : ConversionContext ,
3354
+ target : Target ,
3355
+ args : Tuple [Argument , ...],
3356
+ kwargs : Dict [str , Argument ],
3357
+ name : str ,
3358
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
3359
+ return impl .upsample .upsample (
3360
+ ctx ,
3361
+ target ,
3362
+ SourceIR .ATEN ,
3363
+ name ,
3364
+ args [0 ],
3365
+ size = args [1 ],
3366
+ scale_factor = None if len (args ) < 6 else [args [3 ], args [4 ], args [5 ]],
3367
+ mode = "trilinear" ,
3368
+ align_corners = args [2 ],
3369
+ )
3370
+
3371
+
3372
+ @dynamo_tensorrt_converter (
3373
+ torch .ops .aten .upsample_bicubic2d .default , supports_dynamic_shapes = True
3374
+ )
3375
+ @enforce_tensor_types (
3376
+ {
3377
+ 0 : (TRTTensor ,),
3378
+ }
3379
+ )
3380
+ def aten_ops_upsample_bicubic2d (
3381
+ ctx : ConversionContext ,
3382
+ target : Target ,
3383
+ args : Tuple [Argument , ...],
3384
+ kwargs : Dict [str , Argument ],
3385
+ name : str ,
3386
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
3387
+ return impl .upsample .upsample (
3388
+ ctx ,
3389
+ target ,
3390
+ SourceIR .ATEN ,
3391
+ name ,
3392
+ args [0 ],
3393
+ size = args [1 ],
3394
+ scale_factor = None if len (args ) < 5 else [args [3 ], args [4 ]],
3395
+ mode = "bicubic" ,
3396
+ align_corners = args [2 ],
3085
3397
)
3086
3398
3087
3399
0 commit comments