Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

flip memory coalescing for last dim case #10310

Merged
merged 18 commits into from
Aug 15, 2023
Merged

Conversation

marigoold
Copy link
Contributor

@marigoold marigoold commented Aug 7, 2023

针对 dim = -1 时候访存无法合并的情况做了优化。
实现方式是先 flip 后写到 shared memory,然后统一从 shm 中顺序写到 global memory 中,此时可以合并访存。
对比:

import oneflow as flow
x = flow.randn(32, 2048).cuda()

flow._oneflow_internal.profiler.RangePush("flip_prof")
for i in range(100):
    y = flow.flip(x, [-1])
flow._oneflow_internal.profiler.RangePop()

import torch
x = torch.randn(32, 2048).cuda()

torch.cuda.nvtx.range_push("flip_prof")
for i in range(100):
    y = torch.flip(x, [-1])
torch.cuda.nvtx.range_pop()
  • nsys 结果对比

torch kernel:
image

oneflow kernel(优化后):
image

oneflow kernel(优化前):
image

@github-actions
Copy link
Contributor

github-actions bot commented Aug 7, 2023

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

@marigoold marigoold enabled auto-merge (squash) August 7, 2023 06:46
@github-actions
Copy link
Contributor

github-actions bot commented Aug 7, 2023

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/10310/

@github-actions
Copy link
Contributor

github-actions bot commented Aug 7, 2023

Speed stats:
GPU Name: NVIDIA GeForce RTX 3080 Ti 

❌ OneFlow resnet50 time: 43.8ms (= 4382.0ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 57.2ms (= 5723.5ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.31 (= 57.2ms / 43.8ms)

OneFlow resnet50 time: 25.9ms (= 2586.5ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 38.3ms (= 3834.2ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.48 (= 38.3ms / 25.9ms)

OneFlow resnet50 time: 18.7ms (= 3747.8ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 35.5ms (= 7092.7ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.89 (= 35.5ms / 18.7ms)

OneFlow resnet50 time: 19.2ms (= 3833.0ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 32.4ms (= 6477.1ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.69 (= 32.4ms / 19.2ms)

OneFlow resnet50 time: 17.2ms (= 3439.4ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 29.2ms (= 5848.9ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.70 (= 29.2ms / 17.2ms)

OneFlow swin dataloader time: 0.202s (= 40.326s / 200, num_workers=1)
PyTorch swin dataloader time: 0.128s (= 25.619s / 200, num_workers=1)
Relative speed: 0.635 (= 0.128s / 0.202s)

OneFlow swin dataloader time: 0.054s (= 10.783s / 200, num_workers=4)
PyTorch swin dataloader time: 0.033s (= 6.523s / 200, num_workers=4)
Relative speed: 0.605 (= 0.033s / 0.054s)

OneFlow swin dataloader time: 0.032s (= 6.300s / 200, num_workers=8)
PyTorch swin dataloader time: 0.017s (= 3.318s / 200, num_workers=8)
Relative speed: 0.527 (= 0.017s / 0.032s)

❌ OneFlow resnet50 time: 47.7ms (= 4765.0ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 64.1ms (= 6407.5ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.34 (= 64.1ms / 47.7ms)

OneFlow resnet50 time: 32.5ms (= 3254.7ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 44.1ms (= 4405.7ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.35 (= 44.1ms / 32.5ms)

OneFlow resnet50 time: 24.0ms (= 4809.3ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 41.2ms (= 8243.3ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.71 (= 41.2ms / 24.0ms)

OneFlow resnet50 time: 22.8ms (= 4557.6ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 36.8ms (= 7364.0ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.62 (= 36.8ms / 22.8ms)

OneFlow resnet50 time: 21.1ms (= 4217.1ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 34.1ms (= 6815.1ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.62 (= 34.1ms / 21.1ms)

@github-actions
Copy link
Contributor

github-actions bot commented Aug 7, 2023

CI failed when running job: cuda-module. PR label automerge has been removed

@github-actions github-actions bot removed the automerge label Aug 7, 2023
@github-actions
Copy link
Contributor

github-actions bot commented Aug 7, 2023

CI failed when running job: cuda-module. PR label automerge has been removed

@github-actions
Copy link
Contributor

github-actions bot commented Aug 7, 2023

CI failed when running job: cuda-module. PR label automerge has been removed

@github-actions github-actions bot removed the automerge label Aug 7, 2023
@github-actions
Copy link
Contributor

github-actions bot commented Aug 7, 2023

CI failed when running job: cuda-module. PR label automerge has been removed

@github-actions
Copy link
Contributor

github-actions bot commented Aug 7, 2023

CI failed when running job: cuda-module. PR label automerge has been removed

@github-actions
Copy link
Contributor

github-actions bot commented Aug 7, 2023

CI failed when running job: cuda-module. PR label automerge has been removed

4 similar comments
@github-actions
Copy link
Contributor

github-actions bot commented Aug 7, 2023

CI failed when running job: cuda-module. PR label automerge has been removed

@github-actions
Copy link
Contributor

github-actions bot commented Aug 7, 2023

CI failed when running job: cuda-module. PR label automerge has been removed

@github-actions
Copy link
Contributor

github-actions bot commented Aug 7, 2023

CI failed when running job: cuda-module. PR label automerge has been removed

@github-actions
Copy link
Contributor

github-actions bot commented Aug 8, 2023

CI failed when running job: cuda-module. PR label automerge has been removed

@github-actions
Copy link
Contributor

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/10310/

@github-actions
Copy link
Contributor

Speed stats:
GPU Name: NVIDIA GeForce RTX 3080 Ti 

❌ OneFlow resnet50 time: 43.8ms (= 4382.5ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 61.2ms (= 6115.3ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.40 (= 61.2ms / 43.8ms)

OneFlow resnet50 time: 26.0ms (= 2595.7ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 37.8ms (= 3777.2ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.46 (= 37.8ms / 26.0ms)

OneFlow resnet50 time: 18.7ms (= 3742.1ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 35.3ms (= 7058.7ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.89 (= 35.3ms / 18.7ms)

OneFlow resnet50 time: 18.3ms (= 3664.2ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 34.0ms (= 6793.4ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.85 (= 34.0ms / 18.3ms)

OneFlow resnet50 time: 18.9ms (= 3784.8ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 29.2ms (= 5837.6ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.54 (= 29.2ms / 18.9ms)

OneFlow swin dataloader time: 0.201s (= 40.190s / 200, num_workers=1)
PyTorch swin dataloader time: 0.129s (= 25.828s / 200, num_workers=1)
Relative speed: 0.643 (= 0.129s / 0.201s)

OneFlow swin dataloader time: 0.055s (= 10.987s / 200, num_workers=4)
PyTorch swin dataloader time: 0.033s (= 6.547s / 200, num_workers=4)
Relative speed: 0.596 (= 0.033s / 0.055s)

OneFlow swin dataloader time: 0.031s (= 6.145s / 200, num_workers=8)
PyTorch swin dataloader time: 0.017s (= 3.312s / 200, num_workers=8)
Relative speed: 0.539 (= 0.017s / 0.031s)

❌ OneFlow resnet50 time: 47.8ms (= 4783.5ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 64.1ms (= 6409.2ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.34 (= 64.1ms / 47.8ms)

OneFlow resnet50 time: 30.6ms (= 3058.8ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 43.3ms (= 4331.9ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.42 (= 43.3ms / 30.6ms)

OneFlow resnet50 time: 24.3ms (= 4858.6ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 41.8ms (= 8367.8ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.72 (= 41.8ms / 24.3ms)

OneFlow resnet50 time: 21.8ms (= 4369.9ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 38.9ms (= 7775.8ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.78 (= 38.9ms / 21.8ms)

OneFlow resnet50 time: 20.9ms (= 4180.4ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 34.4ms (= 6871.7ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.64 (= 34.4ms / 20.9ms)

@marigoold marigoold removed the request for review from oneflow-ci-bot August 15, 2023 01:55
@marigoold marigoold marked this pull request as draft August 15, 2023 01:55
@github-actions
Copy link
Contributor

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/10310/

@github-actions
Copy link
Contributor

Speed stats:
GPU Name: NVIDIA GeForce RTX 3080 Ti 

❌ OneFlow resnet50 time: 43.6ms (= 4364.1ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 58.8ms (= 5883.0ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.35 (= 58.8ms / 43.6ms)

OneFlow resnet50 time: 26.1ms (= 2608.5ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 38.5ms (= 3853.6ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.48 (= 38.5ms / 26.1ms)

OneFlow resnet50 time: 18.8ms (= 3754.3ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 35.8ms (= 7154.7ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.91 (= 35.8ms / 18.8ms)

OneFlow resnet50 time: 18.9ms (= 3782.3ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 32.0ms (= 6400.6ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.69 (= 32.0ms / 18.9ms)

OneFlow resnet50 time: 17.5ms (= 3508.3ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 30.0ms (= 6005.6ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.71 (= 30.0ms / 17.5ms)

OneFlow swin dataloader time: 0.203s (= 40.613s / 200, num_workers=1)
PyTorch swin dataloader time: 0.129s (= 25.815s / 200, num_workers=1)
Relative speed: 0.636 (= 0.129s / 0.203s)

OneFlow swin dataloader time: 0.056s (= 11.164s / 200, num_workers=4)
PyTorch swin dataloader time: 0.033s (= 6.505s / 200, num_workers=4)
Relative speed: 0.583 (= 0.033s / 0.056s)

OneFlow swin dataloader time: 0.031s (= 6.195s / 200, num_workers=8)
PyTorch swin dataloader time: 0.017s (= 3.345s / 200, num_workers=8)
Relative speed: 0.540 (= 0.017s / 0.031s)

❌ OneFlow resnet50 time: 47.4ms (= 4741.2ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 66.3ms (= 6626.5ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.40 (= 66.3ms / 47.4ms)

OneFlow resnet50 time: 31.0ms (= 3103.2ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 44.9ms (= 4494.4ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.45 (= 44.9ms / 31.0ms)

OneFlow resnet50 time: 24.4ms (= 4875.6ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 42.2ms (= 8430.8ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.73 (= 42.2ms / 24.4ms)

OneFlow resnet50 time: 22.2ms (= 4431.4ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 36.6ms (= 7328.2ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.65 (= 36.6ms / 22.2ms)

OneFlow resnet50 time: 21.1ms (= 4214.1ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 34.0ms (= 6806.1ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.62 (= 34.0ms / 21.1ms)

@marigoold marigoold marked this pull request as ready for review August 15, 2023 02:46
@github-actions
Copy link
Contributor

Speed stats:

@marigoold marigoold enabled auto-merge (squash) August 15, 2023 03:18
@github-actions
Copy link
Contributor

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

@marigoold marigoold requested review from oneflow-ci-bot and removed request for oneflow-ci-bot August 15, 2023 03:19
@github-actions
Copy link
Contributor

CI failed when running job: cpu-module. PR label automerge has been removed

@github-actions
Copy link
Contributor

Speed stats:

@github-actions
Copy link
Contributor

CI failed when running job: cuda-module. PR label automerge has been removed

@github-actions
Copy link
Contributor

Speed stats:

@github-actions
Copy link
Contributor

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/10310/

@github-actions
Copy link
Contributor

Speed stats:
GPU Name: NVIDIA GeForce RTX 3080 Ti 

❌ OneFlow resnet50 time: 43.6ms (= 4362.4ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 57.5ms (= 5751.3ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.32 (= 57.5ms / 43.6ms)

OneFlow resnet50 time: 25.9ms (= 2589.8ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 38.8ms (= 3876.3ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.50 (= 38.8ms / 25.9ms)

OneFlow resnet50 time: 19.8ms (= 3967.2ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 36.2ms (= 7248.5ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.83 (= 36.2ms / 19.8ms)

OneFlow resnet50 time: 17.2ms (= 3433.1ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 31.8ms (= 6368.4ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.85 (= 31.8ms / 17.2ms)

OneFlow resnet50 time: 17.5ms (= 3493.4ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 28.7ms (= 5749.8ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.65 (= 28.7ms / 17.5ms)

OneFlow swin dataloader time: 0.200s (= 39.964s / 200, num_workers=1)
PyTorch swin dataloader time: 0.129s (= 25.896s / 200, num_workers=1)
Relative speed: 0.648 (= 0.129s / 0.200s)

OneFlow swin dataloader time: 0.055s (= 10.928s / 200, num_workers=4)
PyTorch swin dataloader time: 0.033s (= 6.501s / 200, num_workers=4)
Relative speed: 0.595 (= 0.033s / 0.055s)

OneFlow swin dataloader time: 0.030s (= 6.011s / 200, num_workers=8)
PyTorch swin dataloader time: 0.017s (= 3.325s / 200, num_workers=8)
Relative speed: 0.553 (= 0.017s / 0.030s)

❌ OneFlow resnet50 time: 47.6ms (= 4764.9ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 63.5ms (= 6350.8ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.33 (= 63.5ms / 47.6ms)

OneFlow resnet50 time: 31.1ms (= 3112.3ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 44.8ms (= 4477.5ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.44 (= 44.8ms / 31.1ms)

OneFlow resnet50 time: 24.3ms (= 4866.6ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 41.5ms (= 8295.7ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.70 (= 41.5ms / 24.3ms)

OneFlow resnet50 time: 22.5ms (= 4501.8ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 36.7ms (= 7342.9ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.63 (= 36.7ms / 22.5ms)

OneFlow resnet50 time: 21.8ms (= 4364.2ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 34.0ms (= 6798.9ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.56 (= 34.0ms / 21.8ms)

@marigoold marigoold merged commit 2d24fe0 into master Aug 15, 2023
20 checks passed
@marigoold marigoold deleted the optimize_flip_last_dim branch August 15, 2023 04:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants