21 changes: 13 additions & 8 deletions tests/python/direct/test_cutlass_nvfp4_gemm.py
@@ -6,15 +6,8 @@
import pytest
import torch
from nvfuser_direct import nvf_cutlass

compute_cap = torch.cuda.get_device_capability()
if compute_cap < (10, 0) or compute_cap >= (12, 0):
    pytest.skip(
        reason="Nvfp4 Requires compute capability 10.",
        allow_module_level=True,
    )

from python.direct_utils import (
    microarchitecture_is,
    FLOAT4_E2M1_MAX,
    FLOAT8_E4M3_MAX,
    dequantize_to_dtype,
@@ -52,6 +45,10 @@ def get_ref_results(
    return torch.matmul(a_in_dtype, b_in_dtype.t())


@pytest.mark.skipif(
    not microarchitecture_is(10, 0),
    reason="Does not support blackwell compute 12.0, other arches are not tested.",
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
"shape", [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
@@ -100,6 +97,10 @@ def test_nvfp4_gemm(
    torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1)


@pytest.mark.skipif(
    not microarchitecture_is(10, 0),
    reason="Does not support blackwell compute 12.0, other arches are not tested.",
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
"shape", [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
@@ -175,6 +176,10 @@ def test_nvfp4_gemm_epilogue(
)


@pytest.mark.skipif(
    not microarchitecture_is(10, 0),
    reason="Does not support blackwell compute 12.0, other arches are not tested.",
)
@pytest.mark.parametrize("config", [[1024, 128, 256]])
@pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
24 changes: 9 additions & 15 deletions tests/python/direct/test_narrow_precision.py
@@ -16,6 +16,8 @@
    pytorch_nvfp4_quantize,
    is_pre_blackwell,
    microarchitecture_is_pre,
    is_blackwell,
    microarchitecture_is,
    linear_to_swizzled_128_4,
    round_up,
    activation_scale_to_nvfp4,
@@ -241,10 +243,8 @@ def nvfuser_fusion_id0(fd: FusionDefinition):

# cannot use opinfo test, because the input tensor dtype and fusion definition dtype doesn't match
@pytest.mark.skipif(
    is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
)
@pytest.mark.skipif(
    not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0"
    not microarchitecture_is(10, 0),
    reason="Does not support blackwell compute 12.0, other arches are not tested.",
)
Comment on lines 245 to 248 (Contributor):
style: Inconsistent architecture filtering within the same file: this test uses microarchitecture_is(10, 0) (10.0 only), but other tests in this file (lines 66, 173, 320) use is_blackwell(), which accepts all Blackwell variants (10.0, 10.3, 12.0, 12.1). Do the scaled_mm tests have specific incompatibilities with 12.x compute capabilities that the quantization tests don't?

Comment on lines 245 to 248 (Contributor):
The skip condition logic has changed significantly. The old conditions were:

  • is_pre_blackwell() → skip if major < 10
  • not microarchitecture_is_pre(12) → skip if major >= 12

This ran tests on architectures where: 10 <= major < 12 (i.e., 10.0, 10.3, and any 11.x devices)

The new condition not microarchitecture_is(10, 0) runs tests ONLY on exactly 10.0.

If the intent is to support 10.3 (B300/GB300) in addition to 10.0, the condition should be:

Suggested change (replace the skipif with):
@pytest.mark.skipif(
    not (microarchitecture_is(10, 0) or microarchitecture_is(10, 3)),
    reason="Does not support blackwell compute 12.0, other arches are not tested.",
)

However, if 10.3 is deliberately excluded because it hasn't been tested, the current change is correct; just note that it is more restrictive than the original logic.
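
To make the before/after concrete, here is a minimal sketch (hypothetical helpers, not part of this PR), assuming microarchitecture_is(major, minor) is an exact compute-capability match and microarchitecture_is_pre(major) means "major version < N", consistent with how the helpers are used in this diff:

# Hypothetical illustration only. The removed module-level skip in
# test_cutlass_nvfp4_gemm.py compared torch.cuda.get_device_capability()
# against (major, minor) tuples, so the same convention is assumed here.
def runs_under_old_gating(cc):
    # Old: skipped if is_pre_blackwell() (major < 10) or if
    # not microarchitecture_is_pre(12) (major >= 12).
    return not (cc[0] < 10) and not (cc[0] >= 12)

def runs_under_new_gating(cc):
    # New: skipped unless the device is exactly compute capability 10.0.
    return cc == (10, 0)

for cc in [(9, 0), (10, 0), (10, 3), (11, 0), (12, 0), (12, 1)]:
    print(cc, runs_under_old_gating(cc), runs_under_new_gating(cc))
# Only (10, 0) runs under the new gating; (10, 3) and any 11.x device
# ran under the old gating but are now skipped.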

@pytest.mark.parametrize("config", [[128, 256, 512], [128, 256, 512]])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
@@ -323,9 +323,7 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
    torch.testing.assert_close(outputs[0], ref_outputs, rtol=1e-1, atol=1e-2)


@pytest.mark.skipif(
    is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
)
@pytest.mark.skipif(not is_blackwell(), reason="Only supported on blackwell.")
Comment (Contributor):
Inconsistent architecture check: this test uses is_blackwell() which includes compute capabilities 10.0, 10.3, 12.0, and 12.1. However:

  1. Other similar tests in this file that use fd.ops.scaled_mm (like test_scaled_mm at line 245) use microarchitecture_is(10, 0) - only 10.0
  2. The cutlass scheduler only supports 10.0 and 10.3 (not 12.0 or 12.1)

This test uses fd.ops.scaled_mm and fd.ops.nv_block_quantize, which may rely on cutlass scheduling. If so, it should use the same skip condition as other tests:

Suggested change:
@pytest.mark.skipif(not microarchitecture_is(10, 0), reason="Only supported on blackwell.")

If this test is genuinely intended to support all Blackwell variants (10.0, 10.3, 12.0, 12.1) because scaled_mm uses a different scheduler (nvmmh), then the is_blackwell() function needs to be fixed to exclude 12.0 and 12.1, OR this needs documentation explaining why it's different.

@pytest.mark.parametrize("config", [[1024, 1024, 1024]])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
def test_scaled_mm_nv_quantized(
@@ -454,10 +452,8 @@ def fusion_baseline(fd: FusionDefinition) -> None:


@pytest.mark.skipif(
    is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
)
@pytest.mark.skipif(
    not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0"
    not microarchitecture_is(10, 0),
    reason="Does not support blackwell compute 12.0, other arches are not tested.",
)
Comment on lines 454 to 457 (Contributor):
style: Same inconsistency: microarchitecture_is(10, 0) here vs is_blackwell() at line 320 for similar grouped matmul tests


Comment on lines 454 to 457 (Contributor):
Same issue as test_scaled_mm: the skip condition logic has changed from supporting 10.x and 11.x architectures to only supporting exactly 10.0.

Old: skipped if is_pre_blackwell() or not microarchitecture_is_pre(12) → ran on 10 <= major < 12
New: skipped if not microarchitecture_is(10, 0) → runs only on major == 10 and minor == 0

This test uses fd.ops.cutlass_nvfp4_grouped_mm which is explicitly a cutlass operation. The cutlass scheduler supports both 10.0 and 10.3 according to csrc/scheduler/cutlass.cpp. Consider updating to support both tested architectures if 10.3 should be included.

@pytest.mark.parametrize("config", [[1024, 128, 256]])
@pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]])
@@ -661,10 +657,8 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
# 1. inputs data needs to be changed from `torch.testing.make_tensor` to `torch.randn`;
# 2. output errors are much more relaxed.
@pytest.mark.skipif(
    is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
)
@pytest.mark.skipif(
    not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0"
    not microarchitecture_is(10, 0),
    reason="Does not support blackwell compute 12.0, other arches are not tested.",
)
Comment on lines 659 to 662 (Contributor):
style: Same pattern: microarchitecture_is(10, 0) here vs is_blackwell() elsewhere in file


Comment on lines 659 to 662 (Contributor):
Same issue: the skip condition has become more restrictive, changing from supporting 10.x and 11.x to only 10.0. This test also uses fd.ops.cutlass_nvfp4_grouped_mm which should support both 10.0 and 10.3 based on the cutlass scheduler implementation.

@pytest.mark.parametrize("config", [[1024, 128, 256]])
@pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]])
9 changes: 3 additions & 6 deletions tests/python/direct/test_with_id_model_indexer.py
@@ -15,8 +15,7 @@
    FLOAT8_E4M3_EPS,
    FLOAT8_E4M3_MAX,
    pytorch_nvfp4_quantize,
    is_pre_blackwell,
    microarchitecture_is_pre,
    microarchitecture_is,
    linear_to_swizzled_128_4,
    round_up,
    activation_scale_to_nvfp4,
@@ -31,10 +30,8 @@
# assertion. Having this as a separate test file would avoid environment
# variable contamination from others.
@pytest.mark.skipif(
    is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
)
@pytest.mark.skipif(
    not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0"
    not microarchitecture_is(10, 0),
    reason="Does not support blackwell compute 12.0, other arches are not tested.",
)
Comment on lines 32 to 35 (Contributor):
The skip condition has changed from supporting 10.x and 11.x architectures (skipped if is_pre_blackwell() or not microarchitecture_is_pre(12)) to supporting only exactly 10.0. This is consistent with the changes in test_narrow_precision.py, but it is more restrictive test coverage than before. Verify this is intentional.

@pytest.mark.parametrize("config", [[1024, 128, 256]])
@pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]])
13 changes: 13 additions & 0 deletions tests/python/direct_utils/utils.py
@@ -35,6 +35,19 @@ def is_pre_blackwell():
    return microarchitecture_is_pre(10)


# 10.0 (B200 and GB200)
# 10.3 (B300 and GB300)
# 12.0 (RTX PRO 6000 and RTX 50XX)
# 12.1 (DGX Spark)
def is_blackwell():
    return (
        microarchitecture_is(10, 0)
        or microarchitecture_is(10, 3)
        or microarchitecture_is(12, 0)
        or microarchitecture_is(12, 1)
    )
Comment on lines +42 to +48 (Contributor):
The is_blackwell() function incorrectly includes compute capability 12.0 and 12.1 as "Blackwell" architectures. However, the cutlass scheduler (csrc/scheduler/cutlass.cpp lines 75-84) explicitly only supports compute capabilities 10.0 and 10.3:

if (device_prop->major != 10 ||
    !(device_prop->minor == 0 || device_prop->minor == 3)) {
  // error: "Cutlass scheduler only supports GB200 and GB300 (cc 10.0 or 10.3)"
}

Compute capabilities 12.0 (RTX PRO 6000 and RTX 50XX) and 12.1 (DGX Spark) are NOT Blackwell architectures and are NOT supported by the cutlass scheduler.

This function should only return true for 10.0 and 10.3:

Suggested change (drop the 12.x entries):
def is_blackwell():
    return (
        microarchitecture_is(10, 0)
        or microarchitecture_is(10, 3)
    )

Alternatively, if 12.0 and 12.1 are legitimately Blackwell variants, the function should be renamed to clarify it's broader than what cutlass supports, and tests should use more specific checks.
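
If the rename route is taken, a minimal sketch of the split could look like the following (the names are hypothetical, and it is assumed to live next to is_blackwell() in tests/python/direct_utils/utils.py, where microarchitecture_is is already in scope):

# Hypothetical sketch, not part of this PR: keep a family-wide predicate for
# all Blackwell compute capabilities listed above, plus a narrower one that
# mirrors the cutlass scheduler check (cc 10.0 or 10.3 only).
def is_blackwell_family():
    return (
        microarchitecture_is(10, 0)
        or microarchitecture_is(10, 3)
        or microarchitecture_is(12, 0)
        or microarchitecture_is(12, 1)
    )


def is_cutlass_supported_blackwell():
    return microarchitecture_is(10, 0) or microarchitecture_is(10, 3)

Tests that exercise the cutlass path (scaled_mm, cutlass_nvfp4_grouped_mm) would then skip on not is_cutlass_supported_blackwell(), while quantization-only tests could keep the broader family check.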



# Get string representation for FusionDefinition
# Run captured python definition
# Check that the result of captured python definition matches original results