Merged
7 changes: 1 addition & 6 deletions manual_ci.sh
@@ -17,7 +17,6 @@ run_test() {

cd "$(dirname "${BASH_SOURCE[0]}")"

run_test './bin/lib/dynamic_type/test_dynamic_type_17'
run_test './bin/lib/dynamic_type/test_dynamic_type_20'
run_test './bin/test_nvfuser'
run_test './bin/test_rng'
@@ -26,16 +25,12 @@ if type -p mpirun > /dev/null
then
run_test mpirun -np 1 './bin/test_multidevice'
fi
run_test './bin/test_view'
run_test './bin/test_matmul'
run_test './bin/test_external_src'
run_test './bin/tutorial'
run_test './bin/test_python_frontend'
run_test './bin/test_profiler'

run_test 'pytest tests/python/test_ops.py'
run_test 'pytest tests/python/test_python_frontend.py'
run_test 'pytest tests/python/test_schedule_ops.py'
run_test 'pytest tests/python/direct'

if $failed_tests;
then
21 changes: 21 additions & 0 deletions python/python_frontend/fusion_cache.cpp
@@ -347,6 +347,27 @@ FusionCache* FusionCache::get(
if (load_from_default_workspace && fs::exists(file_path)) {
try {
singleton_->deserialize(file_path);

// Check if deserialized cache exceeds max_fusions limit
if (singleton_->fusions_.size() > max_fusions) {
std::cout
<< "Warning: Deserialized cache contains "
<< singleton_->fusions_.size()
<< " fusions, which exceeds the requested max_fusions limit of "
<< max_fusions << ". Resetting cache." << std::endl;

// Delete incompatible workspace
std::error_code remove_ec;
fs::remove(file_path, remove_ec);
if (remove_ec) {
std::cout << "Failed to delete common workspace. Exception:\t"
<< remove_ec.message() << std::endl;
}

// Reset FusionCache
delete singleton_;
singleton_ = new FusionCache(max_fusions, selected_device);
}
} catch (const std::exception& deserialize_exception) {
// The saved workspace can become out-of-date between nvfuser updates.
// Send warning and delete the incompatible workspace.
8 changes: 7 additions & 1 deletion tests/cpp/test_circular_buffering.cpp
@@ -94,7 +94,7 @@ TEST_F(NVFuserTest, BarSyncWarpSpecializedPointwise) {
}

TEST_F(NVFuserTest, RegisterSharingCircularBufferingPointwiseCustom) {
NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(9, 0, 11, 0);
std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

@@ -2014,6 +2014,9 @@ TEST_P(TmaCircularBufferingTest, Persistent) {
TEST_P(TmaCircularBufferingTest, Matmul) {
NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);

// Check shared memory requirements for matmul tests
REQUIRE_DEVICE_SMEM_SIZE(147584, 0);

if (testEnablesTIDx()) {
GTEST_SKIP() << "Warp Specialization with TIDx used for both computation "
"and load, requires TIDx to be a multiple of 128.";
@@ -2157,6 +2160,9 @@ TEST_P(TmaCircularBufferingTest, Matmul) {
TEST_P(TmaCircularBufferingTest, MatmulWithBroadcastedInput) {
NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);

// Check shared memory requirements for matmul tests
REQUIRE_DEVICE_SMEM_SIZE(147584, 0);

if (testEnablesTIDx()) {
GTEST_SKIP() << "Warp Specialization with TIDx used for both computation "
"and load, requires TIDx to be a multiple of 128.";
10 changes: 8 additions & 2 deletions tests/cpp/test_combined_inner_outer_reduction.cpp
@@ -1067,7 +1067,13 @@ class TmaWarpSpecializedTest
ASSERT_NE(heur, nullptr);
ASSERT_TRUE(heur->isA<ReductionParams>());
auto* rparams = heur->as<ReductionParams>();
EXPECT_TRUE(rparams->computation_warp_groups > 1);

// Skip computation_warp_groups check for devices with compute capability
// 12. This heuristics check will fail due to smaller shared memory
// capacities with larger input sizes.
if (cudaArchGuardShouldSkip(12, 0, 13, 0)) {
EXPECT_TRUE(rparams->computation_warp_groups > 1);
}
}

protected:
@@ -1327,7 +1333,7 @@ TEST(StaticWarpReductionTest, StaticWarpReductionValidation) {
EnableOption::WarpSpecializedNormalization);

int64_t dim0 = 2048;
int64_t dim1 = 8192;
int64_t dim1 = 4096;
DataType dtype = DataType::Float;

auto fusion_ptr = std::make_unique<Fusion>();
10 changes: 10 additions & 0 deletions tests/cpp/test_memory.cpp
@@ -558,6 +558,12 @@ TEST_P(TMALoadTestWithABroadcastDim, LoadWithBroadcast) {
FusionGuard fg(&fusion);
auto shape = std::get<0>(GetParam());

uint64_t required_smem = dataTypeSizeByte(dtype);
for (auto dim : shape)
required_smem *= dim;

REQUIRE_DEVICE_SMEM_SIZE(required_smem, 0);

auto tv0 = makeContigConcreteTensor(shape, dtype);
fusion.addInput(tv0);
auto tv1 = set(tv0);
@@ -960,6 +966,8 @@ TEST_F(TMAIndexingTest, DefineBoxByCompositing2) {
}

TEST_F(TMAIndexingTest, DefineBoxByCompositingShouldNotMerge) {
REQUIRE_DEVICE_SMEM_SIZE(131080, 0);

Fusion fusion;
FusionGuard fg(&fusion);

@@ -1009,6 +1017,8 @@ TEST_F(TMAIndexingTest, DefineBoxByCompositingShouldNotMerge) {
}

TEST_F(TMAIndexingTest, DefineBoxByRotation1) {
REQUIRE_DEVICE_SMEM_SIZE(124424, 0);

Fusion fusion;
FusionGuard fg(&fusion);

5 changes: 3 additions & 2 deletions tests/cpp/utils.h
@@ -470,8 +470,9 @@ class HopperBase : public NVFuserTest {
class BlackwellBase : public NVFuserTest {
protected:
void SetUp() override {
if (cudaArchGuardShouldSkip(10, 0)) {
GTEST_SKIP() << "skipping tests on non-Blackwell GPUs";
if (cudaArchGuardShouldSkip(10, 0, 11, 0)) {
GTEST_SKIP() << "skipping tests on non-Blackwell GPUs (requires "
"sm_100/sm_104, not sm_110+)";
}
NVFuserTest::SetUp();
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
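For readers unfamiliar with the arch guards used above, here is a minimal Python sketch of the semantics these changes appear to rely on. It assumes the range guards (e.g. `NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(9, 0, 11, 0)` and `cudaArchGuardShouldSkip(10, 0, 11, 0)`) skip a test when the device's compute capability falls outside the half-open range `[min, max)`; the actual C++ helpers in tests/cpp/utils.h may differ in detail.

```python
# Assumed semantics of the range guards, sketched in Python for illustration.
import torch

def should_skip_outside_range(min_cc: tuple, max_cc: tuple, device: int = 0) -> bool:
    # Skip when the device's (major, minor) compute capability is not in [min_cc, max_cc).
    cc = torch.cuda.get_device_capability(device)
    return not (min_cc <= cc < max_cc)

# BlackwellBase keeps only sm_100-class devices:
#   should_skip_outside_range((10, 0), (11, 0)) -> False on sm_100, True on sm_90 or sm_120.
```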
4 changes: 4 additions & 0 deletions tests/python/direct/test_cutlass_gemm.py
@@ -6,12 +6,16 @@
import pytest
import torch
from python.direct_utils import is_pre_blackwell
from python.direct_utils import microarchitecture_is_pre
from nvfuser_direct import nvf_cutlass


@pytest.mark.skipif(
is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
)
@pytest.mark.skipif(
not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0."
)
@pytest.mark.parametrize("config", [[1024, 128, 256], [32, 128, 256]])
@pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8], [5, 7, 9]])
@pytest.mark.parametrize("tensor_dtype", [torch.bfloat16, torch.float16])
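`microarchitecture_is_pre` is imported from `python.direct_utils` but not defined in this diff. A minimal sketch of the behavior the new skip markers rely on might look like the following; this is an assumption for illustration, and the real helper may have a different signature or logic.

```python
import torch

def microarchitecture_is_pre(major: int, device: int = 0) -> bool:
    # True when the device's compute capability major version is older than
    # `major`, e.g. microarchitecture_is_pre(12) is True on sm_90/sm_100 and
    # False on sm_120 (consumer Blackwell), which these tests now skip.
    return torch.cuda.get_device_capability(device)[0] < major
```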
5 changes: 3 additions & 2 deletions tests/python/direct/test_cutlass_nvfp4_gemm.py
@@ -7,9 +7,10 @@
import torch
from nvfuser_direct import nvf_cutlass

if torch.cuda.get_device_capability() < (10, 0):
compute_cap = torch.cuda.get_device_capability()
if compute_cap < (10, 0) or compute_cap >= (12, 0):
pytest.skip(
reason="Nvfp4 Requires compute capability of 10 or above.",
reason="Nvfp4 Requires compute capability 10.",
allow_module_level=True,
)

7 changes: 7 additions & 0 deletions tests/python/direct/test_narrow_precision.py
@@ -15,6 +15,7 @@
FLOAT8_E4M3_MAX,
pytorch_nvfp4_quantize,
is_pre_blackwell,
microarchitecture_is_pre,
linear_to_swizzled_128_4,
round_up,
activation_scale_to_nvfp4,
@@ -36,6 +37,9 @@ def nvfp4_quantize(x):
@pytest.mark.skipif(
is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
)
@pytest.mark.skipif(
not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0"
)
@pytest.mark.parametrize("config", [[128, 256, 512], [128, 256, 512]])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
def test_scaled_mm(
@@ -114,6 +118,9 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
@pytest.mark.skipif(
is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
)
@pytest.mark.skipif(
not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0"
)
@pytest.mark.parametrize("config", [[1024, 128, 256]])
@pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
8 changes: 8 additions & 0 deletions tests/python/direct/test_repro.py
@@ -6,6 +6,7 @@
import torch
import pytest
from nvfuser_direct import FusionDefinition, DataType
from python.direct_utils import skip_if_global_memory_below_gb

# Use smaller range for torch.testing.make_tensor for nvfuser_direct.validate
LOW_VAL = -2
@@ -1637,6 +1638,7 @@ def test_issue2664_repro4(nvfuser_direct_test):
T0 has a implicit broadcast which is used in add(T3) and neg (T4). T4 is
used to inplace update T0, which causes RW race.
"""
skip_if_global_memory_below_gb(32)

def fusion_func(fd: FusionDefinition) -> None:
T0 = fd.define_tensor(
@@ -1662,6 +1664,7 @@ def fusion_func(fd: FusionDefinition) -> None:
torch.randn((4194304, 1), dtype=torch.float32, device="cuda:0"),
torch.randn((4194304, 128), dtype=torch.float32, device="cuda:0"),
]

ref_out = [inputs[0] + inputs[1], -inputs[0]]
out, _ = nvfuser_direct_test.exec_nvfuser(fusion_func, inputs)

@@ -1976,6 +1979,7 @@ def test_issue4444(nvfuser_direct_test):
- Proper handling of tensor shapes and operations
- Scalar definition and vector operations
"""
skip_if_global_memory_below_gb(32)

def fusion_func(fd: FusionDefinition) -> None:
T0 = fd.define_tensor(
@@ -2429,6 +2433,7 @@ def test_ws_tma_normalization1(nvfuser_direct_test):
This test verifies complex tensor operations with BFloat16 data type,
including reshape, cast, broadcast, and mathematical operations.
"""
skip_if_global_memory_below_gb(32)

def fusion_func(fd: FusionDefinition) -> None:
T0 = fd.define_tensor(
@@ -2736,6 +2741,7 @@ def test_ws_tma_normalization3(nvfuser_direct_test):
This test verifies complex tensor operations with BFloat16 and Float data types,
including reshape, cast, broadcast, and mathematical operations.
"""
skip_if_global_memory_below_gb(32)

def fusion_func(fd: FusionDefinition) -> None:
T0 = fd.define_tensor(
@@ -2977,6 +2983,7 @@ def test_ws_tma_normalization5(nvfuser_direct_test):
This test verifies complex tensor operations with BFloat16 and Float data types,
including reshape, cast, broadcast, and mathematical operations.
"""
skip_if_global_memory_below_gb(32)

def fusion_func(fd: FusionDefinition) -> None:
T0 = fd.define_tensor(
@@ -3993,6 +4000,7 @@ def test_ws_tma_normalization6(nvfuser_direct_test):
This test verifies complex tensor operations with BFloat16 and Float data types,
including scalar tensor operations, reshape, cast, broadcast, and mathematical operations.
"""
skip_if_global_memory_below_gb(32)

def fusion_func(fd: FusionDefinition) -> None:
T0 = fd.define_tensor(
13 changes: 13 additions & 0 deletions tests/python/direct_utils/utils.py
@@ -4,6 +4,7 @@
# Owner(s): ["module: nvfuser"]

import torch
import pytest
from nvfuser_direct import FusionDefinition, DataType, TensorView
from looseversion import LooseVersion

@@ -112,3 +113,15 @@ def create_sdpa_rng_tensors() -> tuple[torch.Tensor, torch.Tensor]:
philox_seed = torch.testing.make_tensor(philox_shape, device=device, dtype=dtype)
philox_offset = torch.testing.make_tensor((), device=device, dtype=dtype)
return philox_seed, philox_offset


def skip_if_global_memory_below_gb(min_gb: int, gpu_id: int = 0):
device_properties = torch.cuda.get_device_properties(gpu_id)
total_memory_bytes = device_properties.total_memory
min_bytes = min_gb * (1024**3)

if total_memory_bytes < min_bytes:
pytest.skip(
f"Insufficient GPU global memory: requires ~{min_bytes} B, "
f"but only {total_memory_bytes} B available"
)
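A test then calls this helper at the top of its body, as the new calls added to tests/python/direct/test_repro.py above do. A minimal usage sketch (the test name here is hypothetical):

```python
def test_large_inputs(nvfuser_direct_test):
    # Skip on GPUs with less than 32 GiB of global memory; the large inputs
    # used by the repro would not fit otherwise.
    skip_if_global_memory_below_gb(32)
    ...
```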