diff --git a/manual_ci.sh b/manual_ci.sh
index 6e4c0707869..54c7ff2b51e 100755
--- a/manual_ci.sh
+++ b/manual_ci.sh
@@ -17,7 +17,6 @@ run_test() {
 
 cd "$(dirname "${BASH_SOURCE[0]}")"
 
-run_test './bin/lib/dynamic_type/test_dynamic_type_17'
 run_test './bin/lib/dynamic_type/test_dynamic_type_20'
 run_test './bin/test_nvfuser'
 run_test './bin/test_rng'
@@ -26,16 +25,12 @@ if type -p mpirun > /dev/null
 then
     run_test mpirun -np 1 './bin/test_multidevice'
 fi
-run_test './bin/test_view'
 run_test './bin/test_matmul'
 run_test './bin/test_external_src'
-run_test './bin/tutorial'
 run_test './bin/test_python_frontend'
 run_test './bin/test_profiler'
-run_test 'pytest tests/python/test_ops.py'
-run_test 'pytest tests/python/test_python_frontend.py'
-run_test 'pytest tests/python/test_schedule_ops.py'
+run_test 'pytest tests/python/direct'
 
 if $failed_tests; then
diff --git a/python/python_frontend/fusion_cache.cpp b/python/python_frontend/fusion_cache.cpp
index aec0fde6019..8eef20fa253 100644
--- a/python/python_frontend/fusion_cache.cpp
+++ b/python/python_frontend/fusion_cache.cpp
@@ -347,6 +347,27 @@ FusionCache* FusionCache::get(
   if (load_from_default_workspace && fs::exists(file_path)) {
     try {
       singleton_->deserialize(file_path);
+
+      // Check if deserialized cache exceeds max_fusions limit
+      if (singleton_->fusions_.size() > max_fusions) {
+        std::cout
+            << "Warning: Deserialized cache contains "
+            << singleton_->fusions_.size()
+            << " fusions, which exceeds the requested max_fusions limit of "
+            << max_fusions << ". Resetting cache." << std::endl;
+
+        // Delete incompatible workspace
+        std::error_code remove_ec;
+        fs::remove(file_path, remove_ec);
+        if (remove_ec) {
+          std::cout << "Failed to delete common workspace. Exception:\t"
+                    << remove_ec.message() << std::endl;
+        }
+
+        // Reset FusionCache
+        delete singleton_;
+        singleton_ = new FusionCache(max_fusions, selected_device);
+      }
     } catch (const std::exception& deserialize_exception) {
       // The saved workspace can become out-of-date between nvfuser updates.
       // Send warning and delete the incompatible workspace.
diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp
index 31c9ac10eeb..16f29940d1e 100644
--- a/tests/cpp/test_circular_buffering.cpp
+++ b/tests/cpp/test_circular_buffering.cpp
@@ -94,7 +94,7 @@ TEST_F(NVFuserTest, BarSyncWarpSpecializedPointwise) {
 }
 
 TEST_F(NVFuserTest, RegisterSharingCircularBufferingPointwiseCustom) {
-  NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
+  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(9, 0, 11, 0);
   std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
   FusionGuard fg(fusion.get());
 
@@ -2014,6 +2014,9 @@ TEST_P(TmaCircularBufferingTest, Persistent) {
 TEST_P(TmaCircularBufferingTest, Matmul) {
   NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
 
+  // Check shared memory requirements for matmul tests
+  REQUIRE_DEVICE_SMEM_SIZE(147584, 0);
+
   if (testEnablesTIDx()) {
     GTEST_SKIP() << "Warp Specialization with TIDx used for both computation "
                     "and load, requires TIDx to be a multiple of 128.";
@@ -2157,6 +2160,9 @@ TEST_P(TmaCircularBufferingTest, Matmul) {
 TEST_P(TmaCircularBufferingTest, MatmulWithBroadcastedInput) {
   NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
 
+  // Check shared memory requirements for matmul tests
+  REQUIRE_DEVICE_SMEM_SIZE(147584, 0);
+
   if (testEnablesTIDx()) {
     GTEST_SKIP() << "Warp Specialization with TIDx used for both computation "
                     "and load, requires TIDx to be a multiple of 128.";
diff --git a/tests/cpp/test_combined_inner_outer_reduction.cpp b/tests/cpp/test_combined_inner_outer_reduction.cpp
index 32fd4f535c9..61ed15687cb 100644
--- a/tests/cpp/test_combined_inner_outer_reduction.cpp
+++ b/tests/cpp/test_combined_inner_outer_reduction.cpp
@@ -1067,7 +1067,13 @@ class TmaWarpSpecializedTest
     ASSERT_NE(heur, nullptr);
     ASSERT_TRUE(heur->isA<ReductionParams>());
     auto* rparams = heur->as<ReductionParams>();
-    EXPECT_TRUE(rparams->computation_warp_groups > 1);
+
+    // Skip computation_warp_groups check for devices with compute capability
+    // 12. This heuristic check will fail due to smaller shared memory
+    // capacities with larger input sizes.
+    if (cudaArchGuardShouldSkip(12, 0, 13, 0)) {
+      EXPECT_TRUE(rparams->computation_warp_groups > 1);
+    }
   }
 
 protected:
@@ -1327,7 +1333,7 @@ TEST(StaticWarpReductionTest, StaticWarpReductionValidation) {
       EnableOption::WarpSpecializedNormalization);
 
   int64_t dim0 = 2048;
-  int64_t dim1 = 8192;
+  int64_t dim1 = 4096;
   DataType dtype = DataType::Float;
 
   auto fusion_ptr = std::make_unique<Fusion>();
diff --git a/tests/cpp/test_memory.cpp b/tests/cpp/test_memory.cpp
index 467f9c8baab..2086c700c84 100644
--- a/tests/cpp/test_memory.cpp
+++ b/tests/cpp/test_memory.cpp
@@ -558,6 +558,12 @@ TEST_P(TMALoadTestWithABroadcastDim, LoadWithBroadcast) {
   FusionGuard fg(&fusion);
   auto shape = std::get<0>(GetParam());
 
+  uint64_t required_smem = dataTypeSizeByte(dtype);
+  for (auto dim : shape)
+    required_smem *= dim;
+
+  REQUIRE_DEVICE_SMEM_SIZE(required_smem, 0);
+
   auto tv0 = makeContigConcreteTensor(shape, dtype);
   fusion.addInput(tv0);
   auto tv1 = set(tv0);
@@ -960,6 +966,8 @@ TEST_F(TMAIndexingTest, DefineBoxByCompositing2) {
 }
 
 TEST_F(TMAIndexingTest, DefineBoxByCompositingShouldNotMerge) {
+  REQUIRE_DEVICE_SMEM_SIZE(131080, 0);
+
   Fusion fusion;
   FusionGuard fg(&fusion);
 
@@ -1009,6 +1017,8 @@
 }
 
 TEST_F(TMAIndexingTest, DefineBoxByRotation1) {
+  REQUIRE_DEVICE_SMEM_SIZE(124424, 0);
+
   Fusion fusion;
   FusionGuard fg(&fusion);
 
diff --git a/tests/cpp/utils.h b/tests/cpp/utils.h
index 7fa4ae57ab8..0bc9e4a6f77 100644
--- a/tests/cpp/utils.h
+++ b/tests/cpp/utils.h
@@ -470,8 +470,9 @@ class HopperBase : public NVFuserTest {
 class BlackwellBase : public NVFuserTest {
  protected:
  void SetUp() override {
-    if (cudaArchGuardShouldSkip(10, 0)) {
-      GTEST_SKIP() << "skipping tests on non-Blackwell GPUs";
+    if (cudaArchGuardShouldSkip(10, 0, 11, 0)) {
+      GTEST_SKIP() << "skipping tests on non-Blackwell GPUs (requires "
+                      "sm_100/sm_104, not sm_110+)";
     }
     NVFuserTest::SetUp();
     EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
diff --git a/tests/python/direct/test_cutlass_gemm.py b/tests/python/direct/test_cutlass_gemm.py
index 8223daf2761..e2073ec92d6 100644
--- a/tests/python/direct/test_cutlass_gemm.py
+++ b/tests/python/direct/test_cutlass_gemm.py
@@ -6,12 +6,16 @@
 import pytest
 import torch
 from python.direct_utils import is_pre_blackwell
+from python.direct_utils import microarchitecture_is_pre
 from nvfuser_direct import nvf_cutlass
 
 
 @pytest.mark.skipif(
     is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
 )
+@pytest.mark.skipif(
+    not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0."
+)
 @pytest.mark.parametrize("config", [[1024, 128, 256], [32, 128, 256]])
 @pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8], [5, 7, 9]])
 @pytest.mark.parametrize("tensor_dtype", [torch.bfloat16, torch.float16])
diff --git a/tests/python/direct/test_cutlass_nvfp4_gemm.py b/tests/python/direct/test_cutlass_nvfp4_gemm.py
index 14e3dac6059..d5f63d3cf07 100644
--- a/tests/python/direct/test_cutlass_nvfp4_gemm.py
+++ b/tests/python/direct/test_cutlass_nvfp4_gemm.py
@@ -7,9 +7,10 @@
 import torch
 
 from nvfuser_direct import nvf_cutlass
-if torch.cuda.get_device_capability() < (10, 0):
+compute_cap = torch.cuda.get_device_capability()
+if compute_cap < (10, 0) or compute_cap >= (12, 0):
     pytest.skip(
-        reason="Nvfp4 Requires compute capability of 10 or above.",
+        reason="Nvfp4 requires compute capability 10.",
         allow_module_level=True,
     )
 
diff --git a/tests/python/direct/test_narrow_precision.py b/tests/python/direct/test_narrow_precision.py
index eb725a5df8e..570e65d44ca 100644
--- a/tests/python/direct/test_narrow_precision.py
+++ b/tests/python/direct/test_narrow_precision.py
@@ -15,6 +15,7 @@
     FLOAT8_E4M3_MAX,
     pytorch_nvfp4_quantize,
     is_pre_blackwell,
+    microarchitecture_is_pre,
     linear_to_swizzled_128_4,
     round_up,
     activation_scale_to_nvfp4,
@@ -36,6 +37,9 @@ def nvfp4_quantize(x):
 @pytest.mark.skipif(
     is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
 )
+@pytest.mark.skipif(
+    not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0"
+)
 @pytest.mark.parametrize("config", [[128, 256, 512], [128, 256, 512]])
 @pytest.mark.parametrize("out_dtype", [torch.bfloat16])
 def test_scaled_mm(
@@ -114,6 +118,9 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
 @pytest.mark.skipif(
     is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
 )
+@pytest.mark.skipif(
+    not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0"
+)
 @pytest.mark.parametrize("config", [[1024, 128, 256]])
 @pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]])
 @pytest.mark.parametrize("out_dtype", [torch.bfloat16])
diff --git a/tests/python/direct/test_repro.py b/tests/python/direct/test_repro.py
index 0bfe970311f..38142e66c54 100644
--- a/tests/python/direct/test_repro.py
+++ b/tests/python/direct/test_repro.py
@@ -6,6 +6,7 @@
 import torch
 import pytest
 from nvfuser_direct import FusionDefinition, DataType
+from python.direct_utils import skip_if_global_memory_below_gb
 
 # Use smaller range for torch.testing.make_tensor for nvfuser_direct.validate
 LOW_VAL = -2
@@ -1637,6 +1638,7 @@ def test_issue2664_repro4(nvfuser_direct_test):
     T0 has a implicit broadcast which is used in add(T3) and neg (T4).
     T4 is used to inplace update T0, which causes RW race.
""" + skip_if_global_memory_below_gb(32) def fusion_func(fd: FusionDefinition) -> None: T0 = fd.define_tensor( @@ -1662,6 +1664,7 @@ def fusion_func(fd: FusionDefinition) -> None: torch.randn((4194304, 1), dtype=torch.float32, device="cuda:0"), torch.randn((4194304, 128), dtype=torch.float32, device="cuda:0"), ] + ref_out = [inputs[0] + inputs[1], -inputs[0]] out, _ = nvfuser_direct_test.exec_nvfuser(fusion_func, inputs) @@ -1976,6 +1979,7 @@ def test_issue4444(nvfuser_direct_test): - Proper handling of tensor shapes and operations - Scalar definition and vector operations """ + skip_if_global_memory_below_gb(32) def fusion_func(fd: FusionDefinition) -> None: T0 = fd.define_tensor( @@ -2429,6 +2433,7 @@ def test_ws_tma_normalization1(nvfuser_direct_test): This test verifies complex tensor operations with BFloat16 data type, including reshape, cast, broadcast, and mathematical operations. """ + skip_if_global_memory_below_gb(32) def fusion_func(fd: FusionDefinition) -> None: T0 = fd.define_tensor( @@ -2736,6 +2741,7 @@ def test_ws_tma_normalization3(nvfuser_direct_test): This test verifies complex tensor operations with BFloat16 and Float data types, including reshape, cast, broadcast, and mathematical operations. """ + skip_if_global_memory_below_gb(32) def fusion_func(fd: FusionDefinition) -> None: T0 = fd.define_tensor( @@ -2977,6 +2983,7 @@ def test_ws_tma_normalization5(nvfuser_direct_test): This test verifies complex tensor operations with BFloat16 and Float data types, including reshape, cast, broadcast, and mathematical operations. """ + skip_if_global_memory_below_gb(32) def fusion_func(fd: FusionDefinition) -> None: T0 = fd.define_tensor( @@ -3993,6 +4000,7 @@ def test_ws_tma_normalization6(nvfuser_direct_test): This test verifies complex tensor operations with BFloat16 and Float data types, including scalar tensor operations, reshape, cast, broadcast, and mathematical operations. """ + skip_if_global_memory_below_gb(32) def fusion_func(fd: FusionDefinition) -> None: T0 = fd.define_tensor( diff --git a/tests/python/direct_utils/utils.py b/tests/python/direct_utils/utils.py index 54507c6eedd..c69c5c28ab6 100644 --- a/tests/python/direct_utils/utils.py +++ b/tests/python/direct_utils/utils.py @@ -4,6 +4,7 @@ # Owner(s): ["module: nvfuser"] import torch +import pytest from nvfuser_direct import FusionDefinition, DataType, TensorView from looseversion import LooseVersion @@ -112,3 +113,15 @@ def create_sdpa_rng_tensors() -> tuple[torch.Tensor, torch.Tensor]: philox_seed = torch.testing.make_tensor(philox_shape, device=device, dtype=dtype) philox_offset = torch.testing.make_tensor((), device=device, dtype=dtype) return philox_seed, philox_offset + + +def skip_if_global_memory_below_gb(min_gb: int, gpu_id: int = 0): + device_properties = torch.cuda.get_device_properties(gpu_id) + total_memory_bytes = device_properties.total_memory + min_bytes = min_gb * (1024**3) + + if total_memory_bytes < min_bytes: + pytest.skip( + f"Insufficient GPU global memory: requires ~{min_bytes} B, " + f"but only {total_memory_bytes} B available" + )