From 9534e3905bd17090442291a0cece3c824bc3e0e5 Mon Sep 17 00:00:00 2001
From: Yueming Hao
Date: Fri, 6 Dec 2024 16:54:04 -0800
Subject: [PATCH 1/6] disable donated_buffer for all ops's backend benchmarking

---
 tritonbench/operators/layer_norm/operator.py | 8 --------
 tritonbench/utils/triton_op.py               | 2 ++
 2 files changed, 2 insertions(+), 8 deletions(-)

diff --git a/tritonbench/operators/layer_norm/operator.py b/tritonbench/operators/layer_norm/operator.py
index ecfc4444..6627697c 100644
--- a/tritonbench/operators/layer_norm/operator.py
+++ b/tritonbench/operators/layer_norm/operator.py
@@ -34,14 +34,6 @@ def torch_layer_norm(self, *args):

     @register_benchmark()
     def torch_compile_layer_norm(self, *args):
-        # We need to run backward multiple times for proper benchmarking
-        # so donated buffer have to be disabled
-        if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD:
-            import torch._functorch.config
-
-            torch._functorch.config.donated_buffer = False
-        import torch
-
         @torch.compile
         def inner(*args):
             return F.layer_norm(*args)
diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py
index 0a0aa962..60af5736 100644
--- a/tritonbench/utils/triton_op.py
+++ b/tritonbench/utils/triton_op.py
@@ -40,6 +40,8 @@
     tqdm = None

 logger = logging.getLogger(__name__)
+# TODO: remove this once we have a better way to handle backward benchmarking
+torch._functorch.config.donated_buffer = False


 @dataclass

From 121571b80343eb092770404b98f9ca95bbc1815e Mon Sep 17 00:00:00 2001
From: Yueming Hao
Date: Fri, 6 Dec 2024 17:22:42 -0800
Subject: [PATCH 2/6] disable donated buffer for bwd and fwd_bwd specifically

---
 tritonbench/utils/triton_op.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py
index 60af5736..1c573973 100644
--- a/tritonbench/utils/triton_op.py
+++ b/tritonbench/utils/triton_op.py
@@ -40,8 +40,6 @@
     tqdm = None

 logger = logging.getLogger(__name__)
-# TODO: remove this once we have a better way to handle backward benchmarking
-torch._functorch.config.donated_buffer = False


 @dataclass
@@ -622,6 +620,9 @@ def __init__(
                 self.tb_args.mode == "bwd"
             ), "We only accept test modes: fwd, bwd, fwd_bwd, or fwd_no_grad."
            self.mode = Mode.BWD
+        if self.mode in [Mode.FWD_BWD, Mode.BWD]:
+            # TODO: remove this once we have a better way to handle backward benchmarking
+            torch._functorch.config.donated_buffer = False
         self.device = tb_args.device
         self.required_metrics = (
             list(set(tb_args.metrics.split(",")))

From cdcd06d210d3a87451b8665234bf7d040055fcbb Mon Sep 17 00:00:00 2001
From: Yueming Hao
Date: Mon, 9 Dec 2024 10:03:56 -0800
Subject: [PATCH 3/6] import torch._functorch.config

---
 tritonbench/utils/triton_op.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py
index 1c573973..0a5ce4fc 100644
--- a/tritonbench/utils/triton_op.py
+++ b/tritonbench/utils/triton_op.py
@@ -622,6 +622,7 @@ def __init__(
             self.mode = Mode.BWD
         if self.mode in [Mode.FWD_BWD, Mode.BWD]:
             # TODO: remove this once we have a better way to handle backward benchmarking
+            import torch._functorch.config
             torch._functorch.config.donated_buffer = False
         self.device = tb_args.device
         self.required_metrics = (

From 5aa17b64701ee53e62e7885e2999dfb8c76b170e Mon Sep 17 00:00:00 2001
From: Yueming Hao
Date: Mon, 9 Dec 2024 10:19:56 -0800
Subject: [PATCH 4/6] fix lint

---
 tritonbench/utils/triton_op.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py
index 0a5ce4fc..8cebd4f5 100644
--- a/tritonbench/utils/triton_op.py
+++ b/tritonbench/utils/triton_op.py
@@ -623,6 +623,7 @@ def __init__(
         if self.mode in [Mode.FWD_BWD, Mode.BWD]:
             # TODO: remove this once we have a better way to handle backward benchmarking
             import torch._functorch.config
+
             torch._functorch.config.donated_buffer = False
         self.device = tb_args.device
         self.required_metrics = (

From 50b64db165afcf0a722449dd9b2295340842b78e Mon Sep 17 00:00:00 2001
From: Yueming Hao
Date: Mon, 9 Dec 2024 10:37:30 -0800
Subject: [PATCH 5/6] disable torch._functorch.config.donated_buffer for fused_linear_cross_entropy, geglu, swiglu, and layernorm

---
 tritonbench/operators/geglu/operator.py  | 6 ++++++
 tritonbench/operators/swiglu/operator.py | 6 ++++++
 2 files changed, 12 insertions(+)

diff --git a/tritonbench/operators/geglu/operator.py b/tritonbench/operators/geglu/operator.py
index 9613bc96..61b790f1 100644
--- a/tritonbench/operators/geglu/operator.py
+++ b/tritonbench/operators/geglu/operator.py
@@ -60,6 +60,12 @@ def liger_geglu(self, input) -> Callable:

     @register_benchmark()
     def inductor_geglu(self, input) -> Callable:
+        # TODO: remove this once we have a better way to handle backward benchmarking
+        # We need to run backward multiple times for proper benchmarking
+        # so donated buffer have to be disabled
+        if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD:
+            import torch._functorch.config
+
         compiled = torch.compile(self.baseline_model)
         return lambda: compiled(input)

diff --git a/tritonbench/operators/swiglu/operator.py b/tritonbench/operators/swiglu/operator.py
index b21fede9..92de4aac 100644
--- a/tritonbench/operators/swiglu/operator.py
+++ b/tritonbench/operators/swiglu/operator.py
@@ -60,6 +60,12 @@ def liger_swiglu(self, input) -> Callable:

     @register_benchmark()
     def inductor_swiglu(self, input) -> Callable:
+        # TODO: remove this once we have a better way to handle backward benchmarking
+        # We need to run backward multiple times for proper benchmarking
+        # so donated buffer have to be disabled
+        if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD:
+            import torch._functorch.config
+
         compiled = torch.compile(self.baseline_op)
         return lambda: compiled(input)

From 7024357136544f00476de234f1473f7cc7b6cc89 Mon Sep 17 00:00:00 2001
From: Yueming Hao
Date: Mon, 9 Dec 2024 11:37:30 -0800
Subject: [PATCH 6/6] fix import;remove extra disablement in triton_op

---
 tritonbench/operators/geglu/operator.py      | 1 +
 tritonbench/operators/layer_norm/operator.py | 9 +++++++++
 tritonbench/operators/swiglu/operator.py     | 1 +
 tritonbench/utils/triton_op.py               | 5 -----
 4 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/tritonbench/operators/geglu/operator.py b/tritonbench/operators/geglu/operator.py
index 61b790f1..481241b8 100644
--- a/tritonbench/operators/geglu/operator.py
+++ b/tritonbench/operators/geglu/operator.py
@@ -8,6 +8,7 @@

 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
+    Mode,
     register_benchmark,
     register_x_val,
 )
diff --git a/tritonbench/operators/layer_norm/operator.py b/tritonbench/operators/layer_norm/operator.py
index 6627697c..db48a03d 100644
--- a/tritonbench/operators/layer_norm/operator.py
+++ b/tritonbench/operators/layer_norm/operator.py
@@ -34,6 +34,15 @@ def torch_layer_norm(self, *args):

     @register_benchmark()
     def torch_compile_layer_norm(self, *args):
+        # TODO: remove this once we have a better way to handle backward benchmarking
+        # We need to run backward multiple times for proper benchmarking
+        # so donated buffer have to be disabled
+        if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD:
+            import torch._functorch.config
+
+            torch._functorch.config.donated_buffer = False
+        import torch
+
         @torch.compile
         def inner(*args):
             return F.layer_norm(*args)
diff --git a/tritonbench/operators/swiglu/operator.py b/tritonbench/operators/swiglu/operator.py
index 92de4aac..f3e95bce 100644
--- a/tritonbench/operators/swiglu/operator.py
+++ b/tritonbench/operators/swiglu/operator.py
@@ -7,6 +7,7 @@

 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
+    Mode,
     register_benchmark,
     register_x_val,
 )
diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py
index 8cebd4f5..0a0aa962 100644
--- a/tritonbench/utils/triton_op.py
+++ b/tritonbench/utils/triton_op.py
@@ -620,11 +620,6 @@ def __init__(
                 self.tb_args.mode == "bwd"
             ), "We only accept test modes: fwd, bwd, fwd_bwd, or fwd_no_grad."
             self.mode = Mode.BWD
-        if self.mode in [Mode.FWD_BWD, Mode.BWD]:
-            # TODO: remove this once we have a better way to handle backward benchmarking
-            import torch._functorch.config
-
-            torch._functorch.config.donated_buffer = False
         self.device = tb_args.device
         self.required_metrics = (
             list(set(tb_args.metrics.split(",")))
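
For context, a minimal standalone sketch of the pattern this series converges on follows. It is not code from the repository; the shapes, the compiled layer_norm closure, and the replay loop are made up for illustration. The point is that backward-mode benchmarking replays .backward() on one compiled graph many times, so the functorch donated-buffer optimization has to be turned off first, which is what the per-backend hooks in the patches above set out to do.

# Illustrative sketch only (not part of the patches above).
import torch
import torch.nn.functional as F
import torch._functorch.config

# The knob the series toggles: disable buffer donation before any
# compiled backward is benchmarked.
torch._functorch.config.donated_buffer = False

weight = torch.randn(128, requires_grad=True)
bias = torch.randn(128, requires_grad=True)
x = torch.randn(32, 128, requires_grad=True)

# Stand-in for an operator's torch.compile backend (e.g. layer_norm).
compiled_ln = torch.compile(lambda t: F.layer_norm(t, (128,), weight, bias))

# A harness in bwd / fwd_bwd mode replays backward on the same graph many
# times; with donated buffers enabled, AOTAutograd may reuse or free saved
# activations after the first replay, so repeated backward calls can break.
y = compiled_ln(x)
grad = torch.ones_like(y)
for _ in range(10):
    y.backward(grad, retain_graph=True)
    x.grad = None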