
Commit 648466b

adamomainz authored and facebook-github-bot committed
quick fix to continue with issue 71 (#73)
Summary:
Pull Request resolved: #73

follow up diff

Reviewed By: FindHao

Differential Revision: D66372865

fbshipit-source-id: 94716dd4701949b54bd55a60ed87a99cadbf95e3
1 parent e8f5ba4 commit 648466b

File tree

1 file changed: +3 −4 lines changed


tritonbench/operator_loader/__init__.py

+3 −4
@@ -4,7 +4,6 @@
 from typing import Any, Generator, List, Optional
 
 import torch
-from torch._dynamo.backends.cudagraphs import cudagraphs_inner
 from torch._inductor.utils import gen_gm_and_inputs
 from torch._ops import OpOverload
 from torch.utils._pytree import tree_map_only
@@ -85,6 +84,9 @@ def __init__(
         ), f"AtenOpBenchmark only supports fp16 and fp32, but got {self.dtype}"
 
     def get_input_iter(self) -> Generator:
+        from torch._dynamo.backends.cudagraphs import cudagraphs_inner
+        from torch._inductor.compile_fx import compile_fx
+
         inps_gens = [self.huggingface_loader, self.torchbench_loader, self.timm_loader]
         for inp_gen in inps_gens:
             for inp in inp_gen.get_inputs_for_operator(
@@ -101,9 +103,6 @@ def get_input_iter(self) -> Generator:
                     "aten::convolution_backward",
                 )
                 if self.device == "cuda":
-                    from torch._inductor.compile_fx import compile_fx
-
-
                     cudagraph_eager = cudagraphs_inner(
                         gm, gm_args, copy_outputs=False, copy_inputs=False
                     )
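For context: the change moves two private-API imports (torch._dynamo.backends.cudagraphs.cudagraphs_inner and torch._inductor.compile_fx.compile_fx) from module scope into get_input_iter, so importing tritonbench.operator_loader no longer requires those internals to resolve. Below is a minimal sketch of the resulting pattern; the class body and input generation are elided, and only the import placement reflects the actual file.

from typing import Generator

class AtenOpBenchmark:
    def get_input_iter(self) -> Generator:
        # Deferred imports: these private torch._dynamo / torch._inductor
        # modules are resolved only when benchmark inputs are actually
        # generated, not when tritonbench.operator_loader is first imported.
        from torch._dynamo.backends.cudagraphs import cudagraphs_inner
        from torch._inductor.compile_fx import compile_fx

        # ... input generation elided; see the full diff above ...
        yield from ()

With this placement, an environment where the dynamo/inductor internals are missing or broken fails only when aten-op inputs are requested, not on every import of the package, which matches the "quick fix" framing of the commit message.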
