Disable donated buffer when benchmarking layer_norm with backwards #88

Closed
xuzhao9 wants to merge 3 commits from xz9/fix-layernorm

Conversation

@xuzhao9 (Contributor) commented on Dec 2, 2024:

Torch sets donated_buffer = True by default, but donated buffers do not support running backward multiple times, so we have to disable the option when benchmarking.
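
For reference, a minimal sketch of the failure mode and the workaround, assuming the flag lives at `torch._functorch.config.donated_buffer` (the config path, shapes, and loop below are illustrative, not tritonbench's actual harness):

```python
import torch
import torch._functorch.config as functorch_config

# Assumption: this functorch config flag controls donated buffers; the exact
# module path may differ across PyTorch versions.
functorch_config.donated_buffer = False

ln = torch.nn.LayerNorm(4096, device="cuda")
compiled_ln = torch.compile(ln)

x = torch.randn(8, 4096, device="cuda", requires_grad=True)
y = compiled_ln(x)
dy = torch.randn_like(y)

# A latency benchmark re-runs backward on the same graph many times.
# With donated_buffer=True the saved buffers are donated (reused) on the
# first backward, so later calls would fail; disabling the flag keeps
# them alive across iterations.
for _ in range(10):
    y.backward(dy, retain_graph=True)
    x.grad = None  # drop the accumulated input grad between iterations
```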

Fixes #40

Test plan:

$ python run.py --op layer_norm --bwd
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:26<00:00,  1.13it/s]
  x_val    torch_layer_norm-latency    triton_layer_norm-latency    torch_compile_layer_norm-latency    liger_layer_norm-latency
-------  --------------------------  ---------------------------  ----------------------------------  --------------------------
   1024                    0.06768                      0.0888                              0.068736                    0.068256
   1536                    0.09872                      0.090592                            0.090368                    0.08032
   2048                    0.121568                     0.100352                            0.104608                    0.088224
   2560                    0.149536                     0.107424                            0.122656                    0.097472
   3072                    0.184768                     0.116544                            0.143456                    0.124288
   3584                    0.213216                     0.127264                            0.176576                    0.117312
   4096                    0.240576                     0.1384                              0.195168                    0.123936
   4608                    0.271232                     0.180928                            0.218176                    0.179744
   5120                    0.294272                     0.191328                            0.240352                    0.185056
   5632                    0.31952                      0.199616                            0.26704                     0.197792
   6144                    0.344064                     0.208448                            0.297792                    0.21168
   6656                    0.36864                      0.219232                            0.339936                    0.219552
   7168                    0.393792                     0.226816                            0.365152                    0.22592
   7680                    0.419456                     0.240736                            0.390432                    0.236992
   8192                    0.44576                      0.251936                            0.419776                    0.25088
   8704                    0.480256                     0.264032                            0.448672                    2.67574
   9216                    0.502624                     0.274272                            0.477312                    2.72173
   9728                    0.527168                     0.293152                            0.522656                    2.7551
  10240                    0.554528                     0.30736                             0.549216                    2.78102
  10752                    0.576192                     0.325824                            0.573888                    2.8047
  11264                    0.601088                     0.339392                            0.598272                    2.84749
  11776                    0.635232                     0.351808                            0.631072                    2.8816
  12288                    0.653088                     0.36336                             0.655776                    2.92502
  12800                    0.684352                     0.381472                            0.696512                    2.95024
  13312                    0.708384                     0.391296                            0.720288                    2.97734
  13824                    0.73264                      0.406944                            0.743584                    3.00829
  14336                    0.756224                     0.417472                            0.771136                    3.04874
  14848                    0.781728                     0.434144                            0.79568                     3.07536
  15360                    0.806656                     0.432192                            0.82064                     3.10083
  15872                    0.833216                     0.459456                            0.858624                    3.10598

@facebook-github-bot (Contributor) commented:

@xuzhao9 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@FindHao (Member) left a comment:

LGTM! Thanks for the fix!

@FindHao (Member) commented Dec 2, 2024:

Do you think we should add a full op-replay mode for backward testing? For example, run the forward + backward pass, but only measure the backward phase's latency when benchmarking backward.
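
A rough sketch of that idea, timing only the backward segment of each fwd + bwd iteration with CUDA events (the helper name, sizes, and timing loop are illustrative, not an existing tritonbench API):

```python
import torch

def time_backward_only(model, x, dy, iters=100):
    # Run the full forward + backward every iteration, but bracket only the
    # backward call with CUDA events so the forward cost is excluded.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    total_ms = 0.0
    for _ in range(iters):
        y = model(x)                         # forward (not timed)
        start.record()
        y.backward(dy)                       # backward (timed)
        end.record()
        torch.cuda.synchronize()
        total_ms += start.elapsed_time(end)  # elapsed time in milliseconds
        x.grad = None                        # reset input grad between iterations
    return total_ms / iters

# Illustrative usage on a LayerNorm module:
ln = torch.nn.LayerNorm(4096, device="cuda")
x = torch.randn(8, 4096, device="cuda", requires_grad=True)
dy = torch.randn(8, 4096, device="cuda")
print(f"backward latency: {time_backward_only(ln, x, dy):.4f} ms")
```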

@facebook-github-bot (Contributor) commented:

@xuzhao9 merged this pull request in 2474f1e.

@xuzhao9 deleted the xz9/fix-layernorm branch on December 3, 2024 at 03:03.

Merging this pull request closes: layer_norm backward problem (#40)