Fix donated_buffer issue (#113)
Summary:
The previous PR #104 introduced the following issue:
```
% python run.py --op geglu --mode fwd  --precision fp32 --metrics latency,speedup --csv --cudagraph
  0%|                                                                                                                           | 0/4 [00:03<?, ?it/s]
Caught exception, terminating early with partial results
Traceback (most recent call last):
  File "/scratch/yhao/pta/tritonbench/tritonbench/utils/triton_op.py", line 782, in run
    y_vals: Dict[str, BenchmarkOperatorMetrics] = functools.reduce(
  File "/scratch/yhao/pta/tritonbench/tritonbench/utils/triton_op.py", line 770, in _reduce_benchmarks
    acc[bm_name] = self._do_bench(
  File "/scratch/yhao/pta/tritonbench/tritonbench/utils/triton_op.py", line 981, in _do_bench
    fn = self._get_bm_func(fn_name)
  File "/scratch/yhao/pta/tritonbench/tritonbench/utils/triton_op.py", line 667, in _get_bm_func
    fwd_fn = fwd_fn_lambda(*self.example_inputs)
  File "/scratch/yhao/pta/tritonbench/tritonbench/utils/triton_op.py", line 481, in _inner
    return function(self, *args, **kwargs)
  File "/scratch/yhao/pta/tritonbench/tritonbench/operators/geglu/operator.py", line 69, in inductor_geglu
    compiled = torch.compile(self.baseline_model)
UnboundLocalError: local variable 'torch' referenced before assignment
(B, T, H)
```
We should use `from torch._functorch import config` rather than `import torch._functorch.config`: the latter, placed inside the benchmarking method, binds `torch` as a local name for the whole function, so any code path that skips the import (e.g. `--mode fwd`) then fails with `UnboundLocalError` when calling `torch.compile`.
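
For illustration, a minimal, self-contained sketch of the scoping gotcha (using `os`/`os.path` as stand-ins for `torch`/`torch._functorch.config`; this is not the tritonbench code itself):

```python
import os  # module-level import, analogous to the module-level `import torch`


def broken(flag: bool) -> str:
    if flag:
        # `import os.path` binds the top-level name `os` as a *local* variable
        # of this function, shadowing the module-level import on every code path.
        import os.path
    # When flag is False the local `os` was never assigned, so this raises
    # UnboundLocalError -- the same failure mode hit by `torch.compile` above.
    return os.getcwd()


def fixed(flag: bool) -> str:
    if flag:
        # `from ... import ... as ...` binds only the new name, so the
        # module-level `os` stays visible throughout the function.
        from os import path as os_path  # noqa: F401
    return os.getcwd()


if __name__ == "__main__":
    print(fixed(False))  # works
    try:
        broken(False)
    except UnboundLocalError as exc:
        print(f"reproduced: {exc}")
```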

Pull Request resolved: #113

Reviewed By: adamomainz

Differential Revision: D67110110

Pulled By: FindHao

fbshipit-source-id: e5143b06d0e62fb2a7b83464e23126e73a52ee10
FindHao authored and facebook-github-bot committed Dec 12, 2024
1 parent 7b67b0a commit 3adf655
Showing 3 changed files with 8 additions and 4 deletions.
4 changes: 3 additions & 1 deletion tritonbench/operators/geglu/operator.py
```diff
@@ -64,7 +64,9 @@ def inductor_geglu(self, input) -> Callable:
         # We need to run backward multiple times for proper benchmarking
         # so donated buffer have to be disabled
         if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD:
-            import torch._functorch.config
+            from torch._functorch import config as functorch_config
+
+            functorch_config.donated_buffer = False
 
         compiled = torch.compile(self.baseline_model)
         return lambda: compiled(input)
```
4 changes: 2 additions & 2 deletions tritonbench/operators/layer_norm/operator.py
```diff
@@ -38,9 +38,9 @@ def torch_compile_layer_norm(self, *args):
         # We need to run backward multiple times for proper benchmarking
         # so donated buffer have to be disabled
         if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD:
-            import torch._functorch.config
+            from torch._functorch import config as functorch_config
 
-            torch._functorch.config.donated_buffer = False
+            functorch_config.donated_buffer = False
         import torch
 
         @torch.compile
```
4 changes: 3 additions & 1 deletion tritonbench/operators/swiglu/operator.py
```diff
@@ -64,7 +64,9 @@ def inductor_swiglu(self, input) -> Callable:
         # We need to run backward multiple times for proper benchmarking
         # so donated buffer have to be disabled
         if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD:
-            import torch._functorch.config
+            from torch._functorch import config as functorch_config
+
+            functorch_config.donated_buffer = False
 
         compiled = torch.compile(self.baseline_op)
         return lambda: compiled(input)
```
