Skip to content

Commit dde5e3c

Browse files
committed
chore: import converter_utils from dynamo
1 parent c7211e1 commit dde5e3c

File tree

2 files changed

+5
-8
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl/elementwise

2 files changed

+5
-8
lines changed

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,10 @@
1010
from torch_tensorrt.dynamo._SourceIR import SourceIR
1111
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
1212
from torch_tensorrt.dynamo.conversion.converter_utils import (
13+
broadcast,
1314
broadcast_to_same_shape,
1415
cast_trt_tensor,
1516
get_trt_tensor,
16-
)
17-
from torch_tensorrt.fx.converters.converter_utils import (
18-
broadcast,
1917
has_dynamic_shape,
2018
set_layer_name,
2119
)
@@ -152,7 +150,7 @@ def convert_binary_elementwise(
152150

153151
if has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape):
154152
lhs_val, rhs_val = broadcast(
155-
ctx.net, lhs_val, rhs_val, f"{name}_broadcast_lhs", f"{name}_broadcast_rhs"
153+
ctx, lhs_val, rhs_val, f"{name}_broadcast_lhs", f"{name}_broadcast_rhs"
156154
)
157155
else:
158156
lhs_val, rhs_val = broadcast_to_same_shape(

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from typing import Optional, Union
22

33
import numpy as np
4+
import tensorrt as trt
45
import torch
56
import torch_tensorrt.dynamo.conversion.impl as impl
67
from torch.fx.node import Target
78
from torch_tensorrt import _enums
89
from torch_tensorrt.dynamo._SourceIR import SourceIR
910
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
1011
from torch_tensorrt.dynamo.conversion.converter_utils import (
12+
broadcast,
1113
cast_int_int_div_trt_tensor,
1214
cast_int_or_float_to_bool,
1315
cast_trt_tensor,
@@ -19,11 +21,8 @@
1921
)
2022
from torch_tensorrt.dynamo.conversion.impl.unary import atan, sign
2123
from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
22-
from torch_tensorrt.fx.converters.converter_utils import broadcast
2324
from torch_tensorrt.fx.types import TRTTensor
2425

25-
import tensorrt as trt
26-
2726

2827
def trunc_div(
2928
ctx: ConversionContext,
@@ -258,7 +257,7 @@ def atan2(
258257
if isinstance(other, TRTTensor):
259258
other = cast_trt_tensor(ctx, other, trt.float32, f"{name}_other")
260259

261-
input, other = broadcast(ctx.net, input, other, f"{name}_input", f"{name}_other")
260+
input, other = broadcast(ctx, input, other, f"{name}_input", f"{name}_other")
262261

263262
# Calculate x_zero, y_zero (whether inputs are zero)
264263
x_zero = eq(ctx, target, source_ir, f"{name}_x_zero", input, 0)

0 commit comments

Comments
 (0)