reduce GQA test combinations (#22918)
### Description
* Reduce GQA test combinations to save about 35 minutes of test time in CI
pipelines (the combination-gating pattern is sketched below).
* Show the latency of transformers tests.
* Use a fixed seed in the DMMHA (DecoderMaskedMultiHeadAttention) test to avoid
random failures.
* For test_flash_attn_rocm.py, change the test skipping condition from "has CUDA
EP" to "does not have ROCm EP", so that the tests do not run in CPU builds (see
the second sketch below).
* For test_flash_attn_cuda.py, move the flash attention and memory efficient
attention tests into separate classes, so that a whole test suite can be skipped
instead of checking the condition in each test (see the third sketch below).
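
A minimal sketch of the combination-gating pattern used by the GQA case generators (the names `pipeline_mode`, `gqa_example_test_cases`, and `Config` are illustrative, not the exact helpers in test_flash_attn_cuda.py): in pipeline mode each dimension collapses to a single representative value, otherwise the full matrix is swept.

```python
import itertools
from collections import namedtuple

pipeline_mode = True  # CI sets this to shrink the sweep; False runs the full matrix

Config = namedtuple("Config", ["batch", "seqlens", "num_heads", "head_size"])


def gqa_example_test_cases():
    # One representative value per dimension in pipeline mode, the full sweep otherwise.
    batches = [3] if pipeline_mode else [1, 3, 5]
    seqlens = [(240, 240)] if pipeline_mode else [(127, 127), (35, 35), (2000, 2000), (200, 200), (240, 240)]
    num_heads = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
    head_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]

    for b, s, n, h in itertools.product(batches, seqlens, num_heads, head_sizes):
        yield Config(batch=b, seqlens=s, num_heads=n, head_size=h)


# 1 case in pipeline mode versus 3 * 5 * 4 * 10 = 600 in the full sweep.
print(sum(1 for _ in gqa_example_test_cases()))
```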
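
A hedged sketch of the provider-based skip condition for the ROCm tests: `get_available_providers` is the real onnxruntime API, while the helper and class names below are illustrative placeholders.

```python
import unittest

from onnxruntime import get_available_providers


def has_rocm_ep() -> bool:
    # The skip is driven by the execution providers compiled into this build,
    # not by whether CUDA happens to be available.
    return "ROCMExecutionProvider" in get_available_providers()


@unittest.skipIf(not has_rocm_ep(), reason="ROCm execution provider is not available, skipping tests.")
class ExampleRocmGQATest(unittest.TestCase):
    def test_smoke(self):
        # A real test would run flash attention parity checks on the ROCm EP.
        self.assertIn("ROCMExecutionProvider", get_available_providers())


if __name__ == "__main__":
    unittest.main()
```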
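
And a minimal before/after sketch of why the CUDA tests were regrouped into separate classes (class names and the stub capability check are illustrative): a class-level skipIf reports the whole suite as skipped once, whereas a per-test early return is counted as a pass and hides the missing coverage.

```python
import unittest


def has_flash_attention() -> bool:
    return False  # stub; the real check probes torch.cuda and the available execution providers


class OldStyleGQATest(unittest.TestCase):
    def test_gqa_prompt(self):
        if not has_flash_attention():
            return  # reported as a pass, so missing coverage is invisible
        ...  # parity checks would run here


@unittest.skipIf(not has_flash_attention(), reason="Flash Attention is not available, skipping tests.")
class NewStyleGQATest(unittest.TestCase):
    def test_gqa_prompt(self):
        ...  # the skip is handled once for the whole class and reported as skipped
```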

### Motivation and Context
The GQA tests take too long to run in CI pipelines because there are too many
parameter combinations.

###### Linux GPU CI Pipeline
Before: 5097 passed, 68 skipped, 8 warnings in 1954.64s (0:32:34)
After:  150 passed, 176 skipped, 8 warnings in 530.38s (0:08:50)
Time Saved: **1424** seconds (0:23:44)

###### Windows GPU CUDA CI Pipeline
Before: 1781 passed, 72 skipped, 6 warnings in 605.48s (0:10:05)
After: 116 passed, 118 skipped, 6 warnings in 275.48s (0:04:35) 
Time Saved: **330** seconds (0:05:30)

###### Linux CPU CI Pipeline
Before: 5093 passed, 72 skipped, 4 warnings in 467.04s (0:07:47)
- 212.96s transformers/test_gqa_cpu.py::TestGQA::test_gqa_past
- 154.12s transformers/test_gqa_cpu.py::TestGQA::test_gqa_no_past
- 26.45s
transformers/test_gqa_cpu.py::TestGQA::test_gqa_interactive_one_batch

After: 116 passed, 210 skipped, 4 warnings in 93.41s (0:01:33)
- 0.97s  transformers/test_gqa_cpu.py::TestGQA::test_gqa_past
- 19.23s transformers/test_gqa_cpu.py::TestGQA::test_gqa_no_past
- 2.41s
transformers/test_gqa_cpu.py::TestGQA::test_gqa_interactive_one_batch

Time Saved: **374** seconds (0:06:14).
tianleiwu authored Nov 21, 2024
1 parent 55f0559 commit 8d99b1a
Showing 5 changed files with 98 additions and 116 deletions.
@@ -757,7 +757,7 @@ static void TestDecoderMaskedMultiHeadAttention(bool is_cross_attn = true, bool

OpTester tester("DecoderMaskedMultiHeadAttention", 1, onnxruntime::kMSDomain);
FixedPatternValueGenerator generator{};
RandomValueGenerator random{};
RandomValueGenerator random{123};

// Attributes
tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
170 changes: 79 additions & 91 deletions onnxruntime/test/python/transformers/test_flash_attn_cuda.py
@@ -24,7 +24,7 @@
from parameterized import parameterized
from test_gqa_cpu import smooth_softmax_ref

from onnxruntime import InferenceSession, OrtValue, SessionOptions
from onnxruntime import InferenceSession, OrtValue, SessionOptions, get_available_providers

torch.manual_seed(0)

@@ -1999,6 +1999,8 @@ def parity_check_gqa_past_no_buff(
def has_flash_attention():
if not torch.cuda.is_available():
return False
if "CUDAExecutionProvider" not in get_available_providers():
return False
major, _ = torch.cuda.get_device_capability()
return major >= 8 and (
platform.system() == "Linux"
@@ -2009,6 +2011,8 @@ def has_flash_attention():
def has_memory_efficient():
if not torch.cuda.is_available():
return False
if "CUDAExecutionProvider" not in get_available_providers():
return False
major, minor = torch.cuda.get_device_capability()
if major < 5 or (major == 5 and minor < 3):
return False
@@ -2047,8 +2051,8 @@ def mha_test_cases():
(2048, 2048),
]
)
num_h = [1, 3] if pipeline_mode else [1, 6, 16]
h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
num_h = [3] if pipeline_mode else [1, 6, 16]
h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]

for b in batches:
for s, s2 in seqs:
@@ -2080,11 +2084,7 @@ def gqa_no_past_memory_efficient_test_cases():
batches = [3] if pipeline_mode else [1, 3, 5]
seqs = (
[
(127, 127),
(35, 35),
(2000, 2000),
(200, 200),
(240, 240),
]
if pipeline_mode
else [
@@ -2095,8 +2095,8 @@ def gqa_no_past_memory_efficient_test_cases():
(240, 240),
]
)
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
torch.manual_seed(69)

for b in batches:
Expand All @@ -2121,10 +2121,6 @@ def gqa_no_past_flash_attention_test_cases():
batches = [3] if pipeline_mode else [1, 3, 5]
seqs = (
[
(127, 127),
(35, 35),
(2000, 2000),
(200, 200),
(240, 240),
]
if pipeline_mode
@@ -2136,8 +2132,8 @@ def gqa_no_past_flash_attention_test_cases():
(240, 240),
]
)
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
torch.manual_seed(69)

for b in batches:
@@ -2163,7 +2159,7 @@ def gqa_no_past_flash_attention_test_cases():
def gqa_past_memory_efficient_test_cases():
batches = [5] if pipeline_mode else [1, 3, 5]
seqs = (
[(1, 128), (1, 1024), (1, 2048)]
[(1, 1024)]
if pipeline_mode
else [
(1, 128),
@@ -2179,8 +2175,8 @@ def gqa_past_memory_efficient_test_cases():
# (128, 128),
]
)
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
random.seed(69)

for b in batches:
@@ -2205,7 +2201,7 @@ def gqa_past_memory_efficient_test_cases():
def gqa_past_flash_attention_test_cases():
batches = [5] if pipeline_mode else [1, 3, 5]
seqs = (
[(1, 128), (1, 1024), (1, 2048)]
[(1, 2048)]
if pipeline_mode
else [
(1, 128),
@@ -2221,8 +2217,8 @@ def gqa_past_flash_attention_test_cases():
# (128, 128),
]
)
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
random.seed(69)

for b in batches:
@@ -2249,7 +2245,7 @@ def gqa_past_flash_attention_test_cases():
def gqa_interactive_one_batch_flash_attention_test_cases():
batches = [1]
seqs = (
[(2, 128), (128, 129), (32, 128), (256, 2048)]
[(128, 2048)]
if pipeline_mode
else [
(1, 128),
@@ -2265,8 +2261,8 @@ def gqa_interactive_one_batch_flash_attention_test_cases():
# (128, 128),
]
)
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
random.seed(69)

for b in batches:
@@ -2290,7 +2286,7 @@ def gqa_interactive_one_batch_flash_attention_test_cases():
def gqa_interactive_one_batch_memory_efficient_attention_test_cases():
batches = [1]
seqs = (
[(2, 128), (128, 129), (32, 128), (256, 2048)]
[(32, 128)]
if pipeline_mode
else [
(1, 128),
@@ -2306,8 +2302,8 @@ def gqa_interactive_one_batch_memory_efficient_attention_test_cases():
# (128, 128),
]
)
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
random.seed(69)

for b in batches:
@@ -2326,159 +2322,151 @@ def gqa_interactive_one_batch_memory_efficient_attention_test_cases():
)


class TestGQA(unittest.TestCase):
@parameterized.expand(gqa_no_past_memory_efficient_test_cases())
def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
if not has_memory_efficient():
return
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------")
@unittest.skipIf(not has_flash_attention(), reason="Flash Attention is not available, skipping tests.")
class TestFlashGQA(unittest.TestCase):
@parameterized.expand(gqa_no_past_flash_attention_test_cases())
def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
print("------- FLASH ATTENTION (PROMPT CASE) --------")
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"

parity_check_gqa_prompt(
config,
rtol=5e-3,
atol=5e-3,
local=local,
past_format=Formats.BNSH,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
softcap=softcap,
use_smooth_softmax=False,
use_smooth_softmax=True,
)
parity_check_gqa_prompt_no_buff(
config,
rtol=5e-3,
atol=5e-3,
local=local,
past_format=Formats.BNSH,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
softcap=softcap,
use_smooth_softmax=True,
use_smooth_softmax=False,
)

@parameterized.expand(gqa_no_past_flash_attention_test_cases())
def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
if not has_flash_attention():
return
print("------- FLASH ATTENTION (PROMPT CASE) --------")
@parameterized.expand(gqa_past_flash_attention_test_cases())
def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
print("------- FLASH ATTENTION (TOKEN GEN) -------")
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"

parity_check_gqa_prompt(
parity_check_gqa_past(
config,
local=local,
past_format=Formats.BNSH,
rtol=1e-3,
atol=1e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
softcap=softcap,
use_smooth_softmax=True,
use_smooth_softmax=False,
)
parity_check_gqa_prompt_no_buff(
parity_check_gqa_past_no_buff(
config,
local=local,
past_format=Formats.BNSH,
rtol=1e-3,
atol=1e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
softcap=softcap,
use_smooth_softmax=False,
use_smooth_softmax=True,
)

@parameterized.expand(gqa_past_memory_efficient_test_cases())
def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
if not has_memory_efficient():
return
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
print("-------- MEMORY EFFICIENT (TOKEN GEN) --------")
@parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases())
def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed):
print("------- FLASH ATTENTION (INTERACTIVE) -------")
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"

parity_check_gqa_past(
config,
local=local,
past_format=Formats.BNSH,
rtol=1e-3,
atol=1e-3,
rtol=5e-3,
atol=5e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
softcap=softcap,
use_smooth_softmax=True,
)
parity_check_gqa_past_no_buff(
config,
local=local,
past_format=Formats.BNSH,
rtol=1e-3,
atol=1e-3,
rtol=5e-3,
atol=5e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
softcap=softcap,
use_smooth_softmax=False,
)

@parameterized.expand(gqa_past_flash_attention_test_cases())
def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
if not has_flash_attention():
return
print("------- FLASH ATTENTION (TOKEN GEN) -------")
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"

parity_check_gqa_past(
@unittest.skipIf(not has_memory_efficient(), reason="Memory efficient FMHA is not available, skipping tests.")
class TestMemoryEfficientGQA(unittest.TestCase):
@parameterized.expand(gqa_no_past_memory_efficient_test_cases())
def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------")

parity_check_gqa_prompt(
config,
local=local,
rtol=5e-3,
atol=5e-3,
past_format=Formats.BNSH,
rtol=1e-3,
atol=1e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
softcap=softcap,
use_smooth_softmax=False,
)
parity_check_gqa_past_no_buff(
parity_check_gqa_prompt_no_buff(
config,
local=local,
rtol=5e-3,
atol=5e-3,
past_format=Formats.BNSH,
rtol=1e-3,
atol=1e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
softcap=softcap,
use_smooth_softmax=True,
)

@parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases())
def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed):
if not has_flash_attention():
return
print("------- FLASH ATTENTION (INTERACTIVE) -------")
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
@parameterized.expand(gqa_past_memory_efficient_test_cases())
def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
print("-------- MEMORY EFFICIENT (TOKEN GEN) --------")

parity_check_gqa_past(
config,
local=local,
past_format=Formats.BNSH,
rtol=5e-3,
atol=5e-3,
rtol=1e-3,
atol=1e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
softcap=softcap,
use_smooth_softmax=True,
)
parity_check_gqa_past_no_buff(
config,
local=local,
past_format=Formats.BNSH,
rtol=5e-3,
atol=5e-3,
rtol=1e-3,
atol=1e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
softcap=softcap,
use_smooth_softmax=False,
)

@parameterized.expand(gqa_interactive_one_batch_memory_efficient_attention_test_cases())
def test_gqa_interactive_one_batch_memory_efficient_attention(self, _, config, rotary, rotary_interleaved, packed):
if not has_memory_efficient():
return
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
print("-------- MEMORY EFFICIENT (INTERACTIVE) --------")

