Disable donated buffer when benchmarking layer_norm with backwards (#88)
Summary:
PyTorch sets `donated_buffer = True` by default, but a donated buffer does not support running backward multiple times, so we have to disable the option when benchmarking.

Fixes #40

Pull Request resolved: #88
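
For context, a minimal standalone sketch of the same workaround (not part of this commit): disable `torch._functorch.config.donated_buffer` before compiling so the compiled graph's backward can be rerun in a loop. The shapes, device fallback, and iteration count below are illustrative assumptions.

```
# Minimal sketch of the workaround; shapes, device, and iteration count
# are illustrative assumptions, not taken from this commit.
import torch
import torch.nn.functional as F
import torch._functorch.config

# donated_buffer defaults to True; with it enabled, buffers donated to the
# compiled backward are freed after the first backward call, so rerunning
# backward (as a benchmark loop does) fails. Disable it before compiling.
torch._functorch.config.donated_buffer = False


@torch.compile
def layer_norm_fn(x, normalized_shape, weight, bias):
    return F.layer_norm(x, normalized_shape, weight, bias)


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4096, 1024, device=device, requires_grad=True)
weight = torch.randn(1024, device=device, requires_grad=True)
bias = torch.randn(1024, device=device, requires_grad=True)

y = layer_norm_fn(x, (1024,), weight, bias)
grad = torch.ones_like(y)
for _ in range(10):  # benchmarking reruns backward on the same outputs
    y.backward(grad, retain_graph=True)
```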

Test Plan:
```
$ python run.py --op layer_norm --bwd
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:26<00:00,  1.13it/s]
  x_val    torch_layer_norm-latency    triton_layer_norm-latency    torch_compile_layer_norm-latency    liger_layer_norm-latency
-------  --------------------------  ---------------------------  ----------------------------------  --------------------------
   1024                    0.06768                      0.0888                              0.068736                    0.068256
   1536                    0.09872                      0.090592                            0.090368                    0.08032
   2048                    0.121568                     0.100352                            0.104608                    0.088224
   2560                    0.149536                     0.107424                            0.122656                    0.097472
   3072                    0.184768                     0.116544                            0.143456                    0.124288
   3584                    0.213216                     0.127264                            0.176576                    0.117312
   4096                    0.240576                     0.1384                              0.195168                    0.123936
   4608                    0.271232                     0.180928                            0.218176                    0.179744
   5120                    0.294272                     0.191328                            0.240352                    0.185056
   5632                    0.31952                      0.199616                            0.26704                     0.197792
   6144                    0.344064                     0.208448                            0.297792                    0.21168
   6656                    0.36864                      0.219232                            0.339936                    0.219552
   7168                    0.393792                     0.226816                            0.365152                    0.22592
   7680                    0.419456                     0.240736                            0.390432                    0.236992
   8192                    0.44576                      0.251936                            0.419776                    0.25088
   8704                    0.480256                     0.264032                            0.448672                    2.67574
   9216                    0.502624                     0.274272                            0.477312                    2.72173
   9728                    0.527168                     0.293152                            0.522656                    2.7551
  10240                    0.554528                     0.30736                             0.549216                    2.78102
  10752                    0.576192                     0.325824                            0.573888                    2.8047
  11264                    0.601088                     0.339392                            0.598272                    2.84749
  11776                    0.635232                     0.351808                            0.631072                    2.8816
  12288                    0.653088                     0.36336                             0.655776                    2.92502
  12800                    0.684352                     0.381472                            0.696512                    2.95024
  13312                    0.708384                     0.391296                            0.720288                    2.97734
  13824                    0.73264                      0.406944                            0.743584                    3.00829
  14336                    0.756224                     0.417472                            0.771136                    3.04874
  14848                    0.781728                     0.434144                            0.79568                     3.07536
  15360                    0.806656                     0.432192                            0.82064                     3.10083
  15872                    0.833216                     0.459456                            0.858624                    3.10598
```

Reviewed By: FindHao

Differential Revision: D66667947

Pulled By: xuzhao9

fbshipit-source-id: 14f9304fb3684881b5d0f91635f1cde58b6fcc8e
xuzhao9 authored and facebook-github-bot committed Dec 2, 2024
1 parent c509d84 commit 2474f1e
Showing 1 changed file with 8 additions and 0 deletions: `tritonbench/operators/layer_norm/operator.py`

```
@@ -34,6 +34,14 @@ def torch_layer_norm(self, *args):

    @register_benchmark()
    def torch_compile_layer_norm(self, *args):
        # We need to run backward multiple times for proper benchmarking,
        # so the donated buffer has to be disabled.
        if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD:
            import torch._functorch.config

            torch._functorch.config.donated_buffer = False
        import torch

        @torch.compile
        def inner(*args):
            return F.layer_norm(*args)
```
