 from typing import Any, Generator, List, Optional

 import torch
-from torch._dynamo.backends.cudagraphs import cudagraphs_inner
 from torch._inductor.utils import gen_gm_and_inputs
 from torch._ops import OpOverload
 from torch.utils._pytree import tree_map_only
@@ -85,6 +84,9 @@ def __init__(
         ), f"AtenOpBenchmark only supports fp16 and fp32, but got {self.dtype}"

     def get_input_iter(self) -> Generator:
+        from torch._dynamo.backends.cudagraphs import cudagraphs_inner
+        from torch._inductor.compile_fx import compile_fx
+
         inps_gens = [self.huggingface_loader, self.torchbench_loader, self.timm_loader]
         for inp_gen in inps_gens:
             for inp in inp_gen.get_inputs_for_operator(
@@ -101,9 +103,6 @@ def get_input_iter(self) -> Generator:
                     "aten::convolution_backward",
                 )
                 if self.device == "cuda":
-                    from torch._inductor.compile_fx import compile_fx
-
-
                     cudagraph_eager = cudagraphs_inner(
                         gm, gm_args, copy_outputs=False, copy_inputs=False
                     )