diff --git a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc
index 17685ab82f0ef..208545eacf224 100644
--- a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc
@@ -757,7 +757,7 @@ static void TestDecoderMaskedMultiHeadAttention(bool is_cross_attn = true, bool
   OpTester tester("DecoderMaskedMultiHeadAttention", 1, onnxruntime::kMSDomain);
   FixedPatternValueGenerator generator{};
-  RandomValueGenerator random{};
+  RandomValueGenerator random{123};

   // Attributes
   tester.AddAttribute("num_heads", static_cast<int64_t>(num_heads));
diff --git a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py
index 46ab905977f48..a74d5389e9047 100644
--- a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py
+++ b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py
@@ -24,7 +24,7 @@
 from parameterized import parameterized
 from test_gqa_cpu import smooth_softmax_ref

-from onnxruntime import InferenceSession, OrtValue, SessionOptions
+from onnxruntime import InferenceSession, OrtValue, SessionOptions, get_available_providers

 torch.manual_seed(0)
@@ -1999,6 +1999,8 @@ def parity_check_gqa_past_no_buff(
 def has_flash_attention():
     if not torch.cuda.is_available():
         return False
+    if "CUDAExecutionProvider" not in get_available_providers():
+        return False
     major, _ = torch.cuda.get_device_capability()
     return major >= 8 and (
         platform.system() == "Linux"
@@ -2009,6 +2011,8 @@ def has_flash_attention():
 def has_memory_efficient():
     if not torch.cuda.is_available():
         return False
+    if "CUDAExecutionProvider" not in get_available_providers():
+        return False
     major, minor = torch.cuda.get_device_capability()
     if major < 5 or (major == 5 and minor < 3):
         return False
@@ -2047,8 +2051,8 @@ def mha_test_cases():
             (2048, 2048),
         ]
     )
-    num_h = [1, 3] if pipeline_mode else [1, 6, 16]
-    h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    num_h = [3] if pipeline_mode else [1, 6, 16]
+    h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]

     for b in batches:
         for s, s2 in seqs:
@@ -2080,11 +2084,7 @@ def gqa_no_past_memory_efficient_test_cases():
     batches = [3] if pipeline_mode else [1, 3, 5]
     seqs = (
         [
-            (127, 127),
-            (35, 35),
             (2000, 2000),
-            (200, 200),
-            (240, 240),
         ]
         if pipeline_mode
         else [
@@ -2095,8 +2095,8 @@ def gqa_no_past_memory_efficient_test_cases():
             (240, 240),
         ]
     )
-    num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-    h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+    h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]

     torch.manual_seed(69)
     for b in batches:
@@ -2121,10 +2121,6 @@ def gqa_no_past_flash_attention_test_cases():
     batches = [3] if pipeline_mode else [1, 3, 5]
     seqs = (
         [
-            (127, 127),
-            (35, 35),
-            (2000, 2000),
-            (200, 200),
             (240, 240),
         ]
         if pipeline_mode
@@ -2136,8 +2132,8 @@ def gqa_no_past_flash_attention_test_cases():
             (240, 240),
         ]
     )
-    num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-    h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+    h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]

     torch.manual_seed(69)
     for b in batches:
@@ -2163,7 +2159,7 @@ def gqa_no_past_flash_attention_test_cases():
 def gqa_past_memory_efficient_test_cases():
     batches = [5] if pipeline_mode else [1, 3, 5]
     seqs = (
-        [(1, 128), (1, 1024), (1, 2048)]
+        [(1, 1024)]
         if pipeline_mode
         else [
             (1, 128),
@@ -2179,8 +2175,8 @@ def gqa_past_memory_efficient_test_cases():
             # (128, 128),
         ]
     )
-    num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-    h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+    h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]

     random.seed(69)
     for b in batches:
@@ -2205,7 +2201,7 @@ def gqa_past_memory_efficient_test_cases():
 def gqa_past_flash_attention_test_cases():
     batches = [5] if pipeline_mode else [1, 3, 5]
     seqs = (
-        [(1, 128), (1, 1024), (1, 2048)]
+        [(1, 2048)]
         if pipeline_mode
         else [
             (1, 128),
@@ -2221,8 +2217,8 @@ def gqa_past_flash_attention_test_cases():
             # (128, 128),
         ]
     )
-    num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-    h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+    h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]

     random.seed(69)
     for b in batches:
@@ -2249,7 +2245,7 @@ def gqa_past_flash_attention_test_cases():
 def gqa_interactive_one_batch_flash_attention_test_cases():
     batches = [1]
     seqs = (
-        [(2, 128), (128, 129), (32, 128), (256, 2048)]
+        [(128, 2048)]
         if pipeline_mode
         else [
             (1, 128),
@@ -2265,8 +2261,8 @@ def gqa_interactive_one_batch_flash_attention_test_cases():
             # (128, 128),
         ]
     )
-    num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-    h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+    h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]

     random.seed(69)
     for b in batches:
@@ -2290,7 +2286,7 @@ def gqa_interactive_one_batch_flash_attention_test_cases():
 def gqa_interactive_one_batch_memory_efficient_attention_test_cases():
     batches = [1]
     seqs = (
-        [(2, 128), (128, 129), (32, 128), (256, 2048)]
+        [(32, 128)]
         if pipeline_mode
         else [
             (1, 128),
@@ -2306,8 +2302,8 @@ def gqa_interactive_one_batch_memory_efficient_attention_test_cases():
             # (128, 128),
         ]
     )
-    num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-    h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+    h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]

     random.seed(69)
     for b in batches:
@@ -2326,120 +2322,114 @@ def gqa_interactive_one_batch_memory_efficient_attention_test_cases():
         )


-class TestGQA(unittest.TestCase):
-    @parameterized.expand(gqa_no_past_memory_efficient_test_cases())
-    def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
-        if not has_memory_efficient():
-            return
-        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
-        print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------")
+@unittest.skipIf(not has_flash_attention(), reason="Flash Attention is not available, skipping tests.")
+class TestFlashGQA(unittest.TestCase):
+    @parameterized.expand(gqa_no_past_flash_attention_test_cases())
+    def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
+        print("------- FLASH ATTENTION (PROMPT CASE) --------")
+        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"

         parity_check_gqa_prompt(
             config,
-            rtol=5e-3,
-            atol=5e-3,
+            local=local,
             past_format=Formats.BNSH,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
             softcap=softcap,
-            use_smooth_softmax=False,
+            use_smooth_softmax=True,
         )
         parity_check_gqa_prompt_no_buff(
             config,
-            rtol=5e-3,
-            atol=5e-3,
+            local=local,
             past_format=Formats.BNSH,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
             softcap=softcap,
-            use_smooth_softmax=True,
+            use_smooth_softmax=False,
         )

-    @parameterized.expand(gqa_no_past_flash_attention_test_cases())
-    def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
-        if not has_flash_attention():
-            return
-        print("------- FLASH ATTENTION (PROMPT CASE) --------")
+    @parameterized.expand(gqa_past_flash_attention_test_cases())
+    def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
+        print("------- FLASH ATTENTION (TOKEN GEN) -------")
         os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
-        parity_check_gqa_prompt(
+        parity_check_gqa_past(
             config,
             local=local,
             past_format=Formats.BNSH,
+            rtol=1e-3,
+            atol=1e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
             softcap=softcap,
-            use_smooth_softmax=True,
+            use_smooth_softmax=False,
         )
-        parity_check_gqa_prompt_no_buff(
+        parity_check_gqa_past_no_buff(
             config,
             local=local,
             past_format=Formats.BNSH,
+            rtol=1e-3,
+            atol=1e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
             softcap=softcap,
-            use_smooth_softmax=False,
+            use_smooth_softmax=True,
         )

-    @parameterized.expand(gqa_past_memory_efficient_test_cases())
-    def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
-        if not has_memory_efficient():
-            return
-        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
-        print("-------- MEMORY EFFICIENT (TOKEN GEN) --------")
+    @parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases())
+    def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed):
+        print("------- FLASH ATTENTION (INTERACTIVE) -------")
+        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"

         parity_check_gqa_past(
             config,
+            local=local,
             past_format=Formats.BNSH,
-            rtol=1e-3,
-            atol=1e-3,
+            rtol=5e-3,
+            atol=5e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
-            softcap=softcap,
-            use_smooth_softmax=True,
         )
         parity_check_gqa_past_no_buff(
             config,
+            local=local,
             past_format=Formats.BNSH,
-            rtol=1e-3,
-            atol=1e-3,
+            rtol=5e-3,
+            atol=5e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
-            softcap=softcap,
-            use_smooth_softmax=False,
         )

-    @parameterized.expand(gqa_past_flash_attention_test_cases())
-    def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
-        if not has_flash_attention():
-            return
-        print("------- FLASH ATTENTION (TOKEN GEN) -------")
-        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
-        parity_check_gqa_past(
+
+@unittest.skipIf(not has_memory_efficient(), reason="Memory efficient FMHA is not available, skipping tests.")
+class TestMemoryEfficientGQA(unittest.TestCase):
+    @parameterized.expand(gqa_no_past_memory_efficient_test_cases())
+    def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
+        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
+        print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------")
+
+        parity_check_gqa_prompt(
             config,
-            local=local,
+            rtol=5e-3,
+            atol=5e-3,
             past_format=Formats.BNSH,
-            rtol=1e-3,
-            atol=1e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
             softcap=softcap,
             use_smooth_softmax=False,
         )
-        parity_check_gqa_past_no_buff(
+        parity_check_gqa_prompt_no_buff(
             config,
-            local=local,
+            rtol=5e-3,
+            atol=5e-3,
             past_format=Formats.BNSH,
-            rtol=1e-3,
-            atol=1e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
@@ -2447,38 +2437,36 @@ def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interle
             use_smooth_softmax=True,
         )

-    @parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases())
-    def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed):
-        if not has_flash_attention():
-            return
-        print("------- FLASH ATTENTION (INTERACTIVE) -------")
-        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
+    @parameterized.expand(gqa_past_memory_efficient_test_cases())
+    def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
+        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
+        print("-------- MEMORY EFFICIENT (TOKEN GEN) --------")

         parity_check_gqa_past(
             config,
-            local=local,
             past_format=Formats.BNSH,
-            rtol=5e-3,
-            atol=5e-3,
+            rtol=1e-3,
+            atol=1e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
+            softcap=softcap,
+            use_smooth_softmax=True,
         )
         parity_check_gqa_past_no_buff(
             config,
-            local=local,
             past_format=Formats.BNSH,
-            rtol=5e-3,
-            atol=5e-3,
+            rtol=1e-3,
+            atol=1e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
+            softcap=softcap,
+            use_smooth_softmax=False,
         )

     @parameterized.expand(gqa_interactive_one_batch_memory_efficient_attention_test_cases())
     def test_gqa_interactive_one_batch_memory_efficient_attention(self, _, config, rotary, rotary_interleaved, packed):
-        if not has_memory_efficient():
-            return
         os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
         print("-------- MEMORY EFFICIENT (INTERACTIVE) --------")
diff --git a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py
index 99460722c2469..a5910c28c2975 100644
--- a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py
+++ b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py
@@ -16,16 +16,16 @@
 import onnxruntime


-class TestGQA(unittest.TestCase):
+@unittest.skipIf(
+    (not torch.cuda.is_available())
+    or (platform.system() != "Linux")
+    or ("ROCMExecutionProvider" not in onnxruntime.get_available_providers()),
+    reason="ROCm is not available, skipping tests.",
+)
+class TestRocmGQA(unittest.TestCase):
     @parameterized.expand(gqa_no_past_flash_attention_test_cases())
     def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
         config.ep = "ROCMExecutionProvider"
-        if not torch.cuda.is_available():
-            return
-        if platform.system() != "Linux":
-            return
-        if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
-            return
         print("------- FLASH ATTENTION (PROMPT CASE) --------")

         parity_check_gqa_prompt(
@@ -52,12 +52,6 @@ def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_inte
     @parameterized.expand(gqa_past_flash_attention_test_cases())
     def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
         config.ep = "ROCMExecutionProvider"
-        if not torch.cuda.is_available():
-            return
-        if platform.system() != "Linux":
-            return
-        if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
-            return
         print("------- FLASH ATTENTION (TOKEN GEN) -------")

         parity_check_gqa_past(
diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py
index 08ec5de328b9d..77b4b326bf645 100644
--- a/onnxruntime/test/python/transformers/test_gqa_cpu.py
+++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py
@@ -1900,7 +1900,7 @@ class TestGQA(unittest.TestCase):
     def test_gqa_no_past(self):
         torch.manual_seed(69)
         print("-------- TEST GQA NO PAST (PROMPT CASE) ---------")
-        batches = [1, 3] if pipeline_mode else [1, 3, 5]
+        batches = [3] if pipeline_mode else [1, 3, 5]
         seqs = (
             [
                 (127, 127),
@@ -1916,8 +1916,8 @@ def test_gqa_no_past(self):
                 (8000, 8000),
             ]
         )
-        num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-        h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+        num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+        h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
         for b in batches:
             for sq, skv in seqs:
                 for n, n2 in num_h:
@@ -1954,9 +1954,9 @@ def test_gqa_no_past(self):

     def test_gqa_past(self):
         print("-------- TEST GQA PAST (TOKEN GEN) ---------")
-        batches = [1, 3] if pipeline_mode else [1, 3, 5]
+        batches = [1] if pipeline_mode else [1, 3, 5]
         seqs = (
-            [(1, 128), (1, 1024), (1, 2048)]
+            [(1, 128)]
             if pipeline_mode
             else [
                 (1, 128),
@@ -1972,8 +1972,8 @@ def test_gqa_past(self):
                 # (128, 128),
             ]
         )
-        num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-        h_sizes = [16, 64, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+        num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+        h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
         random.seed(69)
         for b in batches:
             for s, s2 in seqs:
@@ -2018,7 +2018,7 @@ def test_gqa_interactive_one_batch(self):
         print("-------- TEST GQA INTERACTIVE ---------")
         batches = [1]
         seqs = (
-            [(2, 128), (128, 129), (32, 128), (256, 2048)]
+            [(256, 2048)]
             if pipeline_mode
             else [
                 (1, 128),
@@ -2034,8 +2034,8 @@ def test_gqa_interactive_one_batch(self):
                 # (128, 128),
             ]
         )
-        num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-        h_sizes = [16, 64, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+        num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+        h_sizes = [32] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
         random.seed(69)
         for b in batches:
             for s, s2 in seqs:
diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py
index aa1198102f978..3bfbc01086cd3 100644
--- a/tools/ci_build/build.py
+++ b/tools/ci_build/build.py
@@ -2149,7 +2149,7 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs):
                 ],
                 cwd=SCRIPT_DIR,
             )
-            run_subprocess([sys.executable, "-m", "pytest", "transformers"], cwd=cwd)
+            run_subprocess([sys.executable, "-m", "pytest", "--durations=0", "transformers"], cwd=cwd)
             # Restore initial numpy/protobuf version in case other tests use it
             run_subprocess([sys.executable, "-m", "pip", "install", "numpy==" + numpy_init_version])
             run_subprocess([sys.executable, "-m", "pip", "install", "protobuf==" + pb_init_version])