diff --git a/tests/python/direct/test_cutlass_nvfp4_gemm.py b/tests/python/direct/test_cutlass_nvfp4_gemm.py index 8b73f384338..22cdb7aaca4 100644 --- a/tests/python/direct/test_cutlass_nvfp4_gemm.py +++ b/tests/python/direct/test_cutlass_nvfp4_gemm.py @@ -6,15 +6,8 @@ import pytest import torch from nvfuser_direct import nvf_cutlass - -compute_cap = torch.cuda.get_device_capability() -if compute_cap < (10, 0) or compute_cap >= (12, 0): - pytest.skip( - reason="Nvfp4 Requires compute capability 10.", - allow_module_level=True, - ) - from python.direct_utils import ( + microarchitecture_is, FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, dequantize_to_dtype, @@ -52,6 +45,10 @@ def get_ref_results( return torch.matmul(a_in_dtype, b_in_dtype.t()) +@pytest.mark.skipif( + not microarchitecture_is(10, 0), + reason="Does not support blackwell compute 12.0, other arches are not tested.", +) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize( "shape", [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)] @@ -100,6 +97,10 @@ def test_nvfp4_gemm( torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1) +@pytest.mark.skipif( + not microarchitecture_is(10, 0), + reason="Does not support blackwell compute 12.0, other arches are not tested.", +) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize( "shape", [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)] @@ -175,6 +176,10 @@ def test_nvfp4_gemm_epilogue( ) +@pytest.mark.skipif( + not microarchitecture_is(10, 0), + reason="Does not support blackwell compute 12.0, other arches are not tested.", +) @pytest.mark.parametrize("config", [[1024, 128, 256]]) @pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]]) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) diff --git a/tests/python/direct/test_narrow_precision.py b/tests/python/direct/test_narrow_precision.py index 8e16b7abc09..591be769e56 100644 --- a/tests/python/direct/test_narrow_precision.py +++ b/tests/python/direct/test_narrow_precision.py @@ -16,6 +16,8 @@ pytorch_nvfp4_quantize, is_pre_blackwell, microarchitecture_is_pre, + is_blackwell, + microarchitecture_is, linear_to_swizzled_128_4, round_up, activation_scale_to_nvfp4, @@ -241,10 +243,8 @@ def nvfuser_fusion_id0(fd: FusionDefinition): # cannot use opinfo test, because the input tensor dtype and fusion definition dtype doesn't match @pytest.mark.skipif( - is_pre_blackwell(), reason="Only supported on blackwell and newer devices." -) -@pytest.mark.skipif( - not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0" + not microarchitecture_is(10, 0), + reason="Does not support blackwell compute 12.0, other arches are not tested.", ) @pytest.mark.parametrize("config", [[128, 256, 512], [128, 256, 512]]) @pytest.mark.parametrize("out_dtype", [torch.bfloat16]) @@ -323,9 +323,7 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None: torch.testing.assert_close(outputs[0], ref_outputs, rtol=1e-1, atol=1e-2) -@pytest.mark.skipif( - is_pre_blackwell(), reason="Only supported on blackwell and newer devices." -) +@pytest.mark.skipif(not is_blackwell(), reason="Only supported on blackwell.") @pytest.mark.parametrize("config", [[1024, 1024, 1024]]) @pytest.mark.parametrize("out_dtype", [torch.bfloat16]) def test_scaled_mm_nv_quantized( @@ -454,10 +452,8 @@ def fusion_baseline(fd: FusionDefinition) -> None: @pytest.mark.skipif( - is_pre_blackwell(), reason="Only supported on blackwell and newer devices." -) -@pytest.mark.skipif( - not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0" + not microarchitecture_is(10, 0), + reason="Does not support blackwell compute 12.0, other arches are not tested.", ) @pytest.mark.parametrize("config", [[1024, 128, 256]]) @pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]]) @@ -661,10 +657,8 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None: # 1. inputs data needs to be changed from `torch.testing.make_tensor` to `torch.randn`; # 2. output errors are much more relaxed. @pytest.mark.skipif( - is_pre_blackwell(), reason="Only supported on blackwell and newer devices." -) -@pytest.mark.skipif( - not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0" + not microarchitecture_is(10, 0), + reason="Does not support blackwell compute 12.0, other arches are not tested.", ) @pytest.mark.parametrize("config", [[1024, 128, 256]]) @pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]]) diff --git a/tests/python/direct/test_with_id_model_indexer.py b/tests/python/direct/test_with_id_model_indexer.py index b251a1da190..c701f677946 100644 --- a/tests/python/direct/test_with_id_model_indexer.py +++ b/tests/python/direct/test_with_id_model_indexer.py @@ -15,8 +15,7 @@ FLOAT8_E4M3_EPS, FLOAT8_E4M3_MAX, pytorch_nvfp4_quantize, - is_pre_blackwell, - microarchitecture_is_pre, + microarchitecture_is, linear_to_swizzled_128_4, round_up, activation_scale_to_nvfp4, @@ -31,10 +30,8 @@ # assertion. Having this as a separate test file would avoid environment # variable contamination from others. @pytest.mark.skipif( - is_pre_blackwell(), reason="Only supported on blackwell and newer devices." -) -@pytest.mark.skipif( - not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0" + not microarchitecture_is(10, 0), + reason="Does not support blackwell compute 12.0, other arches are not tested.", ) @pytest.mark.parametrize("config", [[1024, 128, 256]]) @pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]]) diff --git a/tests/python/direct_utils/utils.py b/tests/python/direct_utils/utils.py index f5eb652a39e..9b156a51437 100644 --- a/tests/python/direct_utils/utils.py +++ b/tests/python/direct_utils/utils.py @@ -35,6 +35,19 @@ def is_pre_blackwell(): return microarchitecture_is_pre(10) +# 10.0 (B200 and GB200) +# 10.3 (B300 and GB300) +# 12.0 (RTX PRO 6000 and RTX 50XX) +# 12.1 (DGX Spark) +def is_blackwell(): + return ( + microarchitecture_is(10, 0) + or microarchitecture_is(10, 3) + or microarchitecture_is(12, 0) + or microarchitecture_is(12, 1) + ) + + # Get string representation for FusionDefinition # Run captured python definition # Check that the result of captured python definition matches original results