|
1 | 1 | from typing import Optional, Union
|
2 | 2 |
|
3 | 3 | import numpy as np
|
| 4 | +import tensorrt as trt |
4 | 5 | import torch
|
5 | 6 | import torch_tensorrt.dynamo.conversion.impl as impl
|
6 | 7 | from torch.fx.node import Target
|
7 | 8 | from torch_tensorrt import _enums
|
8 | 9 | from torch_tensorrt.dynamo._SourceIR import SourceIR
|
9 | 10 | from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
|
10 | 11 | from torch_tensorrt.dynamo.conversion.converter_utils import (
|
| 12 | + broadcast, |
11 | 13 | cast_int_int_div_trt_tensor,
|
12 | 14 | cast_int_or_float_to_bool,
|
13 | 15 | cast_trt_tensor,
|
|
19 | 21 | )
|
20 | 22 | from torch_tensorrt.dynamo.conversion.impl.unary import atan, sign
|
21 | 23 | from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
|
22 |
| -from torch_tensorrt.fx.converters.converter_utils import broadcast |
23 | 24 | from torch_tensorrt.fx.types import TRTTensor
|
24 | 25 |
|
25 |
| -import tensorrt as trt |
26 |
| - |
27 | 26 |
|
28 | 27 | def trunc_div(
|
29 | 28 | ctx: ConversionContext,
|
@@ -258,7 +257,7 @@ def atan2(
|
258 | 257 | if isinstance(other, TRTTensor):
|
259 | 258 | other = cast_trt_tensor(ctx, other, trt.float32, f"{name}_other")
|
260 | 259 |
|
261 |
| - input, other = broadcast(ctx.net, input, other, f"{name}_input", f"{name}_other") |
| 260 | + input, other = broadcast(ctx, input, other, f"{name}_input", f"{name}_other") |
262 | 261 |
|
263 | 262 | # Calculate x_zero, y_zero (whether inputs are zero)
|
264 | 263 | x_zero = eq(ctx, target, source_ir, f"{name}_x_zero", input, 0)
|
|
0 commit comments