Skip to content

Commit bc00de6

Browse files
authored
Overhaul upsample dynamo converter (#2790)
1 parent fe93ed5 commit bc00de6

File tree

5 files changed

+673
-159
lines changed

5 files changed

+673
-159
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+327-15
Original file line numberDiff line numberDiff line change
@@ -3041,9 +3041,175 @@ def aten_ops_pad(
30413041
)
30423042

30433043

3044-
@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest2d.default)
3045-
@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest2d.vec)
3046-
def upsample_nearest2d(
3044+
# Drop any eagerly-registered Python kernels for the upsample ops so that the
# ``.vec`` decompositions defined below (and the TensorRT converters) take
# effect instead of torch's built-in Autograd / CompositeImplicitAutograd ones.
for _upsample_op in (
    torch.ops.aten.upsample_nearest1d,
    torch.ops.aten.upsample_nearest2d,
    torch.ops.aten.upsample_nearest3d,
    torch.ops.aten.upsample_linear1d,
    torch.ops.aten.upsample_bilinear2d,
    torch.ops.aten.upsample_trilinear3d,
    torch.ops.aten.upsample_bicubic2d,
):
    for _overload in (_upsample_op.default, _upsample_op.vec):
        for _dispatch_key in (
            torch._C.DispatchKey.Autograd,
            torch._C.DispatchKey.CompositeImplicitAutograd,
        ):
            # pop() with a default removes the kernel only if it is present,
            # matching the original "if key in ...: del ..." behavior.
            _overload.py_kernels.pop(_dispatch_key, None)
3061+
3062+
3063+
def upsample_compute_output_size(
3064+
input_size: torch.Size,
3065+
output_size: Optional[Sequence[int]],
3066+
scale_factors: Optional[Sequence[float]],
3067+
) -> Sequence[int]:
3068+
spatial_dimensions = len(input_size) - 2
3069+
3070+
if output_size is not None:
3071+
torch._check(
3072+
scale_factors is None,
3073+
lambda: "Must specify exactly one of output_size and scale_factors",
3074+
)
3075+
torch._check(len(output_size) == spatial_dimensions)
3076+
return output_size
3077+
3078+
if scale_factors is not None:
3079+
torch._check(
3080+
output_size is None,
3081+
lambda: "Must specify exactly one of output_size and scale_factors",
3082+
)
3083+
torch._check(len(scale_factors) == spatial_dimensions)
3084+
output_size = []
3085+
for i, s in enumerate(scale_factors):
3086+
output_size.append(int(input_size[i + 2] * s))
3087+
return output_size
3088+
3089+
torch._check(
3090+
False, lambda: "Must specify exactly one of output_size and scale_factors"
3091+
)
3092+
3093+
3094+
@torch.ops.aten.upsample_nearest1d.vec.py_impl(
    torch._C.DispatchKey.CompositeImplicitAutograd
)
def upsample_nearest1d_vec(
    input: torch.Tensor,
    output_size: Optional[Sequence[int]],
    scale_factors: Optional[Sequence[float]],
) -> torch.Tensor:
    """Decompose ``upsample_nearest1d.vec`` into the ``.default`` overload
    with an explicitly resolved output size."""
    resolved_size = upsample_compute_output_size(
        input.size(), output_size, scale_factors
    )
    if scale_factors is None:
        return torch.ops.aten.upsample_nearest1d.default(input, resolved_size)
    # Forward the per-dimension scale as the optional trailing argument.
    return torch.ops.aten.upsample_nearest1d.default(
        input, resolved_size, *scale_factors
    )
3106+
3107+
3108+
@torch.ops.aten.upsample_nearest2d.vec.py_impl(
    torch._C.DispatchKey.CompositeImplicitAutograd
)
def upsample_nearest2d_vec(
    input: torch.Tensor,
    output_size: Optional[Sequence[int]],
    scale_factors: Optional[Sequence[float]],
) -> torch.Tensor:
    """Decompose ``upsample_nearest2d.vec`` into the ``.default`` overload
    with an explicitly resolved output size."""
    resolved_size = upsample_compute_output_size(
        input.size(), output_size, scale_factors
    )
    if scale_factors is None:
        return torch.ops.aten.upsample_nearest2d.default(input, resolved_size)
    # Forward the per-dimension scales as the optional trailing arguments.
    return torch.ops.aten.upsample_nearest2d.default(
        input, resolved_size, *scale_factors
    )
3120+
3121+
3122+
@torch.ops.aten.upsample_nearest3d.vec.py_impl(
    torch._C.DispatchKey.CompositeImplicitAutograd
)
def upsample_nearest3d_vec(
    input: torch.Tensor,
    output_size: Optional[Sequence[int]],
    scale_factors: Optional[Sequence[float]],
) -> torch.Tensor:
    """Decompose ``upsample_nearest3d.vec`` into the ``.default`` overload
    with an explicitly resolved output size."""
    resolved_size = upsample_compute_output_size(
        input.size(), output_size, scale_factors
    )
    if scale_factors is None:
        return torch.ops.aten.upsample_nearest3d.default(input, resolved_size)
    # Forward the per-dimension scales as the optional trailing arguments.
    return torch.ops.aten.upsample_nearest3d.default(
        input, resolved_size, *scale_factors
    )
3134+
3135+
3136+
@torch.ops.aten.upsample_linear1d.vec.py_impl(
    torch._C.DispatchKey.CompositeImplicitAutograd
)
def upsample_linear1d_vec(
    input: torch.Tensor,
    output_size: Optional[Sequence[int]],
    align_corners: bool,
    scale_factors: Optional[Sequence[float]],
) -> torch.Tensor:
    """Decompose ``upsample_linear1d.vec`` into the ``.default`` overload
    with an explicitly resolved output size."""
    resolved_size = upsample_compute_output_size(
        input.size(), output_size, scale_factors
    )
    if scale_factors is None:
        return torch.ops.aten.upsample_linear1d.default(
            input, resolved_size, align_corners
        )
    # Forward the per-dimension scale as the optional trailing argument.
    return torch.ops.aten.upsample_linear1d.default(
        input, resolved_size, align_corners, *scale_factors
    )
3151+
3152+
3153+
@torch.ops.aten.upsample_bilinear2d.vec.py_impl(
    torch._C.DispatchKey.CompositeImplicitAutograd
)
def upsample_bilinear2d_vec(
    input: torch.Tensor,
    output_size: Optional[Sequence[int]],
    align_corners: bool,
    scale_factors: Optional[Sequence[float]],
) -> torch.Tensor:
    """Decompose ``upsample_bilinear2d.vec`` into the ``.default`` overload
    with an explicitly resolved output size."""
    resolved_size = upsample_compute_output_size(
        input.size(), output_size, scale_factors
    )
    if scale_factors is None:
        return torch.ops.aten.upsample_bilinear2d.default(
            input, resolved_size, align_corners
        )
    # Forward the per-dimension scales as the optional trailing arguments.
    return torch.ops.aten.upsample_bilinear2d.default(
        input, resolved_size, align_corners, *scale_factors
    )
3168+
3169+
3170+
@torch.ops.aten.upsample_trilinear3d.vec.py_impl(
    torch._C.DispatchKey.CompositeImplicitAutograd
)
def upsample_trilinear3d_vec(
    input: torch.Tensor,
    output_size: Optional[Sequence[int]],
    align_corners: bool,
    scale_factors: Optional[Sequence[float]],
) -> torch.Tensor:
    """Decompose ``upsample_trilinear3d.vec`` into the ``.default`` overload
    with an explicitly resolved output size."""
    resolved_size = upsample_compute_output_size(
        input.size(), output_size, scale_factors
    )
    if scale_factors is None:
        return torch.ops.aten.upsample_trilinear3d.default(
            input, resolved_size, align_corners
        )
    # Forward the per-dimension scales as the optional trailing arguments.
    return torch.ops.aten.upsample_trilinear3d.default(
        input, resolved_size, align_corners, *scale_factors
    )
3185+
3186+
3187+
@torch.ops.aten.upsample_bicubic2d.vec.py_impl(
    torch._C.DispatchKey.CompositeImplicitAutograd
)
def upsample_bicubic2d_vec(
    input: torch.Tensor,
    output_size: Optional[Sequence[int]],
    align_corners: bool,
    scale_factors: Optional[Sequence[float]],
) -> torch.Tensor:
    """Decompose ``upsample_bicubic2d.vec`` into the ``.default`` overload
    with an explicitly resolved output size."""
    resolved_size = upsample_compute_output_size(
        input.size(), output_size, scale_factors
    )
    if scale_factors is None:
        return torch.ops.aten.upsample_bicubic2d.default(
            input, resolved_size, align_corners
        )
    # Forward the per-dimension scales as the optional trailing arguments.
    return torch.ops.aten.upsample_bicubic2d.default(
        input, resolved_size, align_corners, *scale_factors
    )
3202+
3203+
3204+
@dynamo_tensorrt_converter(
3205+
torch.ops.aten.upsample_nearest1d.default, supports_dynamic_shapes=True
3206+
)
3207+
@enforce_tensor_types(
3208+
{
3209+
0: (TRTTensor,),
3210+
}
3211+
)
3212+
def aten_ops_upsample_nearest1d(
30473213
ctx: ConversionContext,
30483214
target: Target,
30493215
args: Tuple[Argument, ...],
@@ -3055,17 +3221,23 @@ def upsample_nearest2d(
30553221
target,
30563222
SourceIR.ATEN,
30573223
name,
3058-
input=args[0],
3059-
out_shape=args_bounds_check(args, 1),
3060-
scale_factors=args_bounds_check(args, 2),
3061-
resize_mode="nearest",
3224+
args[0],
3225+
size=args[1],
3226+
scale_factor=None if len(args) < 3 else [args[2]],
3227+
mode="nearest",
30623228
align_corners=False,
30633229
)
30643230

30653231

3066-
@dynamo_tensorrt_converter(torch.ops.aten.upsample_bilinear2d.default)
3067-
@dynamo_tensorrt_converter(torch.ops.aten.upsample_bilinear2d.vec)
3068-
def upsample_bilinear2d(
3232+
@dynamo_tensorrt_converter(
3233+
torch.ops.aten.upsample_nearest2d.default, supports_dynamic_shapes=True
3234+
)
3235+
@enforce_tensor_types(
3236+
{
3237+
0: (TRTTensor,),
3238+
}
3239+
)
3240+
def aten_ops_upsample_nearest2d(
30693241
ctx: ConversionContext,
30703242
target: Target,
30713243
args: Tuple[Argument, ...],
@@ -3077,11 +3249,151 @@ def upsample_bilinear2d(
30773249
target,
30783250
SourceIR.ATEN,
30793251
name,
3080-
input=args[0],
3081-
out_shape=args_bounds_check(args, 1),
3082-
scale_factors=args_bounds_check(args, 3),
3083-
resize_mode="bilinear",
3084-
align_corners=args_bounds_check(args, 2),
3252+
args[0],
3253+
size=args[1],
3254+
scale_factor=None if len(args) < 4 else [args[2], args[3]],
3255+
mode="nearest",
3256+
align_corners=False,
3257+
)
3258+
3259+
3260+
@dynamo_tensorrt_converter(
    torch.ops.aten.upsample_nearest3d.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)
def aten_ops_upsample_nearest3d(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.upsample_nearest3d.default`` via the shared upsample impl."""
    # Optional per-dimension scales occupy args[2:5]; they are forwarded only
    # when all three are present in the argument list.
    if len(args) >= 5:
        scales = [args[2], args[3], args[4]]
    else:
        scales = None
    return impl.upsample.upsample(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        size=args[1],
        scale_factor=scales,
        mode="nearest",
        align_corners=False,
    )
3286+
3287+
3288+
@dynamo_tensorrt_converter(
    torch.ops.aten.upsample_linear1d.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)
def aten_ops_upsample_linear1d(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.upsample_linear1d.default`` via the shared upsample impl."""
    # args layout: (input, output_size, align_corners, scales=None); the single
    # optional scale is forwarded only when present.
    if len(args) >= 4:
        scales = [args[3]]
    else:
        scales = None
    return impl.upsample.upsample(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        size=args[1],
        scale_factor=scales,
        mode="linear",
        align_corners=args[2],
    )
3314+
3315+
3316+
@dynamo_tensorrt_converter(
    torch.ops.aten.upsample_bilinear2d.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)
def aten_ops_upsample_bilinear2d(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.upsample_bilinear2d.default`` via the shared upsample impl."""
    # args layout: (input, output_size, align_corners, scales_h=None,
    # scales_w=None); both scales are forwarded only when present.
    if len(args) >= 5:
        scales = [args[3], args[4]]
    else:
        scales = None
    return impl.upsample.upsample(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        size=args[1],
        scale_factor=scales,
        mode="bilinear",
        align_corners=args[2],
    )
3342+
3343+
3344+
@dynamo_tensorrt_converter(
    torch.ops.aten.upsample_trilinear3d.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)
def aten_ops_upsample_trilinear3d(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.upsample_trilinear3d.default`` via the shared upsample impl."""
    # args layout: (input, output_size, align_corners, then three optional
    # per-dimension scales); the scales are forwarded only when all present.
    if len(args) >= 6:
        scales = [args[3], args[4], args[5]]
    else:
        scales = None
    return impl.upsample.upsample(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        size=args[1],
        scale_factor=scales,
        mode="trilinear",
        align_corners=args[2],
    )
3370+
3371+
3372+
@dynamo_tensorrt_converter(
    torch.ops.aten.upsample_bicubic2d.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)
def aten_ops_upsample_bicubic2d(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.upsample_bicubic2d.default`` via the shared upsample impl."""
    # args layout: (input, output_size, align_corners, then two optional
    # per-dimension scales); the scales are forwarded only when both present.
    if len(args) >= 5:
        scales = [args[3], args[4]]
    else:
        scales = None
    return impl.upsample.upsample(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        size=args[1],
        scale_factor=scales,
        mode="bicubic",
        align_corners=args[2],
    )
30863398

30873399

0 commit comments

Comments
 (0)