Skip to content

Commit a929f14

Browse files
committed
Skip tests that are incompatible with compute capability 12.
1 parent 2b11e73 commit a929f14

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

tests/python/direct/test_cutlass_nvfp4_gemm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
import torch
88
from nvfuser_direct import nvf_cutlass
99

10-
if torch.cuda.get_device_capability() < (10, 0):
10+
compute_cap = torch.cuda.get_device_capability()
11+
if compute_cap < (10, 0) or compute_cap >= (12, 0):
1112
pytest.skip(
12-
reason="Nvfp4 Requires compute capability of 10 or above.",
13+
reason="Nvfp4 Requires compute capability 10.",
1314
allow_module_level=True,
1415
)
1516

tests/python/direct/test_narrow_precision.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
FLOAT8_E4M3_MAX,
1616
pytorch_nvfp4_quantize,
1717
is_pre_blackwell,
18+
is_pre_blackwell_12,
1819
linear_to_swizzled_128_4,
1920
round_up,
2021
activation_scale_to_nvfp4,
@@ -36,6 +37,9 @@ def nvfp4_quantize(x):
3637
@pytest.mark.skipif(
3738
is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
3839
)
40+
@pytest.mark.skipif(
41+
not is_pre_blackwell_12(), reason="Does not support blackwell compute 12.0"
42+
)
3943
@pytest.mark.parametrize("config", [[128, 256, 512], [128, 256, 512]])
4044
@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
4145
def test_scaled_mm(
@@ -114,6 +118,9 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
114118
@pytest.mark.skipif(
115119
is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
116120
)
121+
@pytest.mark.skipif(
122+
not is_pre_blackwell_12(), reason="Does not support blackwell compute 12.0"
123+
)
117124
@pytest.mark.parametrize("config", [[1024, 128, 256]])
118125
@pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]])
119126
@pytest.mark.parametrize("out_dtype", [torch.bfloat16])

tests/python/direct_utils/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ def is_pre_blackwell():
2828
return prop.major < 10
2929

3030

31+
def is_pre_blackwell_12():
    """Return True if the current CUDA device's major compute capability is below 12."""
    # Look up the properties of whichever CUDA device is currently selected.
    device_index = torch.cuda.current_device()
    major_version = torch.cuda.get_device_properties(device_index).major
    return major_version < 12
34+
35+
3136
# Get string representation for FusionDefinition
3237
# Run captured python definition
3338
# Check that the result of captured python definition matches original results

0 commit comments

Comments
 (0)