diff --git a/CMakeLists.txt b/CMakeLists.txt index 4de0fb7ebd1..35ae5f78811 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -387,6 +387,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/scheduler/resize.cpp ${NVFUSER_SRCS_DIR}/scheduler/runtime_info.cpp ${NVFUSER_SRCS_DIR}/scheduler/scheduler_types.cpp + ${NVFUSER_SRCS_DIR}/scheduler/tools/cub_utils.cpp ${NVFUSER_SRCS_DIR}/scheduler/tools/domain_map.cpp ${NVFUSER_SRCS_DIR}/scheduler/tools/inlining.cpp ${NVFUSER_SRCS_DIR}/scheduler/tools/loop_domain_scheduler.cpp @@ -417,34 +418,38 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/validator_utils.cpp ) -# Add LLVM JIT related dependencies -set(LLVM_MINIMUM_VERSION "18.1") -find_package(LLVM REQUIRED CONFIG) -if(${LLVM_VERSION} VERSION_LESS ${LLVM_MINIMUM_VERSION}) - message(FATAL_ERROR "LLVM ${LLVM_VERSION} does not meet the minimum version required: ${LLVM_MINIMUM_VERSION}") -endif() -llvm_map_components_to_libnames(LLVM_LIBS - support - core - orcjit - executionengine - irreader - nativecodegen - Target - Analysis - JITLink - Demangle -) +cmake_dependent_option(NVFUSER_HOST_IR_JIT "Build nvFuser with LLVM" ON "USE_HOST_IR_JIT" OFF) + + +message(STATUS "Setting NVFUSER_HOST_IR_JIT=${NVFUSER_HOST_IR_JIT}") + +if(NVFUSER_HOST_IR_JIT) + add_compile_definitions(NVFUSER_HOST_IR_JIT) + # Add LLVM JIT related dependencies + find_package(LLVM 18.1 REQUIRED CONFIG) + llvm_map_components_to_libnames(LLVM_LIBS + support + core + orcjit + executionengine + irreader + nativecodegen + Target + Analysis + JITLink + Demangle + ) -add_library(LLVM_JIT INTERFACE) -target_include_directories(LLVM_JIT SYSTEM INTERFACE ${LLVM_INCLUDE_DIRS}) -target_compile_definitions(LLVM_JIT INTERFACE ${LLVM_DEFINITIONS}) -target_link_libraries(LLVM_JIT INTERFACE ${LLVM_LIBS}) + add_library(LLVM_JIT INTERFACE) + target_include_directories(LLVM_JIT INTERFACE ${LLVM_INCLUDE_DIRS}) + target_compile_definitions(LLVM_JIT INTERFACE ${LLVM_DEFINITIONS}) + target_link_libraries(LLVM_JIT INTERFACE ${LLVM_LIBS}) -# Add LLVM JIT related sources -list(APPEND NVFUSER_SRCS - ${NVFUSER_SRCS_DIR}/host_ir/jit.cpp -) + # Add LLVM JIT related sources + list(APPEND NVFUSER_SRCS + ${NVFUSER_SRCS_DIR}/host_ir/jit.cpp + ) +endif() # We don't link CUPTI for MSVC if(NOT MSVC) @@ -540,7 +545,9 @@ if (BUILD_CUTLASS AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8) target_compile_definitions(codegen_internal PRIVATE "-DNVFUSER_CUTLASS_KERNEL_ENABLED") endif() -target_link_libraries(codegen_internal PUBLIC LLVM_JIT) +if(NVFUSER_HOST_IR_JIT) + target_link_libraries(codegen_internal PUBLIC LLVM_JIT) +endif() add_library(nvfuser_codegen SHARED $) @@ -575,7 +582,8 @@ target_include_directories(nvfuser_codegen SYSTEM PUBLIC ) target_link_libraries(nvfuser_codegen PUBLIC ${TORCH_LIBRARIES} - PRIVATE dynamic_type flatbuffers ${CUDA_NVRTC_LIB} CUDA::cupti dl LLVM_JIT + PRIVATE dynamic_type flatbuffers ${CUDA_NVRTC_LIB} CUDA::cupti dl + $<$:LLVM_JIT> ) set_target_properties(nvfuser_codegen PROPERTIES C_STANDARD ${NVFUSER_C_STANDARD} @@ -1232,13 +1240,15 @@ if(BUILD_TEST) add_test(test_host_ir "${HOSTIR_TEST_SRCS}" "") list(APPEND TEST_BINARIES test_host_ir) - set(LLVM_COMPILE_TEST_SRCS) - list(APPEND LLVM_COMPILE_TEST_SRCS - ${NVFUSER_ROOT}/tests/cpp/test_host_ir_jit.cpp - ) - add_test(test_host_ir_jit "${LLVM_COMPILE_TEST_SRCS}" "") - target_link_libraries(test_host_ir_jit PUBLIC LLVM_JIT) - list(APPEND TEST_BINARIES test_host_ir_jit) + if(NVFUSER_HOST_IR_JIT) + set(LLVM_COMPILE_TEST_SRCS) + list(APPEND LLVM_COMPILE_TEST_SRCS + 
${NVFUSER_ROOT}/tests/cpp/test_host_ir_jit.cpp + ) + add_test(test_host_ir_jit "${LLVM_COMPILE_TEST_SRCS}" "") + target_link_libraries(test_host_ir_jit PUBLIC LLVM_JIT) + list(APPEND TEST_BINARIES test_host_ir_jit) + endif() # We don't link CUPTI for MSVC @@ -1477,6 +1487,7 @@ message(STATUS " UCC_FOUND: ${UCC_FOUND}") message(STATUS " NVFUSER_STANDALONE_BUILD_WITH_UCC : ${NVFUSER_STANDALONE_BUILD_WITH_UCC}") message(STATUS " NVFUSER_BUILD_WITH_ASAN : ${NVFUSER_BUILD_WITH_ASAN}") message(STATUS " NVFUSER_DISTRIBUTED : ${NVFUSER_DISTRIBUTED}") +message(STATUS " NVFUSER_HOST_IR_JIT : ${NVFUSER_HOST_IR_JIT}") message(STATUS " NVFUSER_CPP_STANDARD : ${NVFUSER_CPP_STANDARD}") message(STATUS " NVMMH_INCLUDE_DIR : ${NVMMH_INCLUDE_DIR}") diff --git a/csrc/host_ir/jit.cpp b/csrc/host_ir/jit.cpp index 8dec5c1afe3..a4188c93587 100644 --- a/csrc/host_ir/jit.cpp +++ b/csrc/host_ir/jit.cpp @@ -492,8 +492,8 @@ void inferTensorShapesAndStrides( // Check if sizes and strides are the same size as logical domain const auto logical_ndims = std::ranges::distance(logical_domain | TensorDomain::kNoReductions); - NVF_ERROR_EQ(std::ssize(sizes), logical_ndims); - NVF_ERROR_EQ(std::ssize(strides), logical_ndims); + NVF_ERROR_EQ(sizes.size(), logical_ndims); + NVF_ERROR_EQ(strides.size(), logical_ndims); } void unpackInputs( diff --git a/csrc/options.cpp b/csrc/options.cpp index 00f81b857f7..174718af872 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -176,7 +176,6 @@ const std::unordered_map& getEnableOptions() { {"warn_register_spill", EnableOption::WarnRegisterSpill}, {"ws_normalization", EnableOption::WarpSpecializedNormalization}, {"host_ir_lowering", EnableOption::HostIrLowering}, - {"host_ir_jit", EnableOption::HostIrJit}, {"insert_resharding_after", EnableOption::InsertReshardingAfter}, {"fast_math", EnableOption::FastMath}, }; diff --git a/csrc/options.h b/csrc/options.h index 4c32143aa36..38cd508aa57 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -120,7 +120,6 @@ enum class EnableOption { WarnRegisterSpill, //! Enable warnings of register spill WarpSpecializedNormalization, //! Enable warp specialized persistent kernel HostIrLowering, //! Enable FusionKernelRuntime lowering to host IR - HostIrJit, //! Enable Host IR JIT compilation with LLVM InsertReshardingAfter, //! Insert resharding set after the expression FastMath, //! Enable fast math optimizations (--use_fast_math) EndOfOption //! 
Placeholder for counting the number of elements diff --git a/csrc/runtime/fusion_kernel_runtime.cpp b/csrc/runtime/fusion_kernel_runtime.cpp index 962b540d675..4041c96a2b4 100644 --- a/csrc/runtime/fusion_kernel_runtime.cpp +++ b/csrc/runtime/fusion_kernel_runtime.cpp @@ -157,7 +157,11 @@ void FusionKernelRuntime::evictCache(size_t input_id) { bool FusionKernelRuntime::isCompiled() const { if (isOptionEnabled(EnableOption::HostIrLowering)) { - return hij_ != nullptr || hie_ != nullptr; +#ifdef NVFUSER_HOST_IR_JIT + return hij_ != nullptr; +#else + return hie_ != nullptr; +#endif } else { std::lock_guard guard(mutex_); return std::all_of( @@ -295,14 +299,13 @@ KernelArgumentHolder FusionKernelRuntime::runWithInputs( << std::endl; } - KernelArgumentHolder outputs; - if (hij_ != nullptr) { - outputs = hij_->runWithInputs(args); - } else if (hie_ != nullptr) { - outputs = hie_->runWithInputs(args); - } else { - NVF_THROW("Neither Host IR JIT or Host IR Evaluator are initialized."); - } +#ifdef NVFUSER_HOST_IR_JIT + auto outputs = + hij_->runWithInputs(args); // TODO: change NVFUSER_HOST_IR_JIT flag to + // enableOption in the future. +#else + auto outputs = hie_->runWithInputs(args); +#endif if (isDebugDumpEnabled(DebugDumpOption::PerfDebugVerbose)) { debug() << "============= FINISHED RUNNING HOSTIR EVALUATOR ============" @@ -469,12 +472,12 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) { } std::unique_ptr hic = lowerSegmentedFusionToHostIr( *segmented_fusion_, launch_params_per_segment, executors_); - if (isOptionEnabled(EnableOption::HostIrJit)) { - hij_ = std::make_unique(std::move(hic)); - } else { - hie_ = std::make_unique( - std::move(hic), &Communicator::getInstance()); - } +#ifdef NVFUSER_HOST_IR_JIT + hij_ = std::make_unique(std::move(hic)); +#else + hie_ = std::make_unique( + std::move(hic), &Communicator::getInstance()); +#endif } if (isProfilerEnabled()) { diff --git a/csrc/runtime/fusion_kernel_runtime.h b/csrc/runtime/fusion_kernel_runtime.h index 31965df07c2..6a16ee27a7f 100644 --- a/csrc/runtime/fusion_kernel_runtime.h +++ b/csrc/runtime/fusion_kernel_runtime.h @@ -11,7 +11,9 @@ #include #include +#ifdef NVFUSER_HOST_IR_JIT #include +#endif #include #include #include @@ -141,13 +143,11 @@ class FusionKernelRuntime { //! Get the Host IR Container const hir::HostIrContainer& getHostIrContainer() const { - if (isOptionEnabled(EnableOption::HostIrJit)) { - NVF_ERROR(hij_ != nullptr, "Host IR JIT is not initialized"); - return hij_->container(); - } else { - NVF_ERROR(hie_ != nullptr, "Host IR Evaluator is not initialized"); - return hie_->container(); - } +#ifdef NVFUSER_HOST_IR_JIT + return hij_->container(); +#else + return hie_->container(); +#endif } private: @@ -189,10 +189,13 @@ class FusionKernelRuntime { //! Executors holding compiled kernels std::vector> executors_; - //! Host IR JIT (used when EnableOption::HostIrJit is set) +#ifdef NVFUSER_HOST_IR_JIT + //! Host IR JIT std::unique_ptr hij_; - //! Host IR Evaluator (used when EnableOption::HostIrJit is not set) +#else + //! Host IR Evaluator std::unique_ptr hie_; +#endif // A metadata copy of initial arguments used to contruct this // FusionKernelRuntime. 
Used during deserialization to schedule the fusion diff --git a/csrc/scheduler/greedy.cpp b/csrc/scheduler/greedy.cpp index 7dffc7dab5e..02cb1d88d56 100644 --- a/csrc/scheduler/greedy.cpp +++ b/csrc/scheduler/greedy.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -634,6 +635,8 @@ class RunTimeChecker : private IterVisitor { max_threads_per_block_( at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock) { traverse(fusion); + + checkSharedMemoryBufferUsage(); } void dispatch(Expr* expr) override { @@ -644,46 +647,88 @@ class RunTimeChecker : private IterVisitor { } void handle(ArgsortOp* argsort) override { - checkDomainConstraints( + int64_t size_of_constrained_ids = checkDomainConstraints( ir_utils::getTvOutput(argsort)->getLogicalDomain(), {argsort->dim()}, + dataTypeSizeByte(ir_utils::getTvOutput(argsort)->dtype()), /*support_batching=*/true); + + int64_t batch_size = + ceilDiv(size_of_constrained_ids, max_threads_per_block_); + int64_t bdimx = std::min(size_of_constrained_ids, max_threads_per_block_); + cub_shmem_buffer_.registerArgsort( + bdimx, batch_size, ir_utils::getTvInput(argsort)->dtype()); } void handle(PadOp* pad) override { checkDomainConstraints( - ir_utils::getTvOutput(pad)->getLogicalDomain(), pad->getPaddedAxes()); + ir_utils::getTvOutput(pad)->getLogicalDomain(), + pad->getPaddedAxes(), + dataTypeSizeByte(ir_utils::getTvOutput(pad)->dtype())); } void handle(ScanOp* scan) override { checkDomainConstraints( ir_utils::getTvOutput(scan)->getLogicalDomain(), {scan->dim()}, + dataTypeSizeByte(ir_utils::getTvOutput(scan)->dtype()), /*support_batching=*/true); } void handle(TopKOp* topk) override { - checkDomainConstraints( - ir_utils::getTvOutput(topk)->getLogicalDomain(), {topk->dim()}); + // TopKOp produces two outputs: one has the same type as the input + // and another is an integer index tensor + int64_t size_of_constrained_ids = checkDomainConstraints( + TensorDomain::noReductions( + ir_utils::getTvInput(topk)->getLogicalDomain()), + {topk->dim()}, + dataTypeSizeByte(ir_utils::getTvInput(topk)->dtype()) + + dataTypeSizeByte(DataType::Int)); + + int64_t batch_size = + ceilDiv(size_of_constrained_ids, max_threads_per_block_); + int64_t bdimx = std::min(size_of_constrained_ids, max_threads_per_block_); + cub_shmem_buffer_.registerTopK( + bdimx, batch_size, ir_utils::getTvInput(topk)->dtype()); } void handle(ScatterOp* scatter) override { auto out = ir_utils::getTvOutput(scatter); auto index = scatter->index()->as(); + // TODO: If the input and output is a fusion input and output, + // there will be no computation for the shape of the logical + // domain, so this check is not necessary. checkDomainConstraints( out->getLogicalDomain(), {scatter->dim()}, + dataTypeSizeByte(out->dtype()), /*support_batching=*/true); + + int64_t index_bytes = dataTypeSizeByte(index->dtype()); + // If it's scalar, ignore the contribution + int64_t src_bytes = scatter->src()->isA() + ? dataTypeSizeByte(scatter->src()->dtype()) + : 0; + checkDomainConstraints( TensorDomain::noReductions(index->getLogicalDomain()), {scatter->dim()}, + index_bytes + src_bytes, /*support_batching=*/true); } - void checkDomainConstraints( + // Check the constraints on the given domain. bytes_per_element + // indicates the size of data required to hold one work item, which + // may correspond to multiple tensor elements. For example, in the + // case of TopKOp, two outputs are produced, so the size should + // cover both of them. 
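+  // (For instance, a TopKOp with a Float input needs 4 + 8 = 12 bytes per
+  // work item, since the index output is produced as DataType::Int, i.e.,
+  // int64_t.)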
+ // + // Returns the size of the constrained IDs in bytes + int64_t checkDomainConstraints( const std::vector& domain, const std::vector& constrained_id_offsets, + int64_t bytes_per_element, bool support_batching = false) { int64_t size_of_constrained_ids = 1; for (const auto i : constrained_id_offsets) { @@ -697,42 +742,77 @@ class RunTimeChecker : private IterVisitor { size_of_constrained_ids *= extent_val.as(); } - // The maximum supported size depends on several factors. The hard - // limit is the shared memory capacity since the kernel launch - // would just fail if the shared memory usage exceeds the - // available size. The next important limit would be the register - // usage as we would not want to have excessive register spilling. - // + const int64_t threads_per_block = max_threads_per_block_; + // At this moment, not all constrained ops supports batching. If // batching is not supported, the limit is simply set as the // maximum number of threads per thread block. This is likely // a sufficient condition even for shared memory, although not // guaranteed. + if (!support_batching) { + if (size_of_constrained_ids > threads_per_block) { + reject( + "Extent of constrained logical IDs, ", + size_of_constrained_ids, + ", exceeds the number of threads per thread block: ", + threads_per_block); + } + } + + // The maximum supported size depends on several factors. The hard + // limit is the shared memory capacity since the kernel launch + // would just fail if the shared memory usage exceeds the + // available size. It is checked at the end of the RunTimeChecker + // constructor. // - // When batching is supported, up to half of the shared memory - // capacity is allowed for now. This is a pretty rough estimate - // and does not guarantee the safety of kernel launches nor avoids - // register spilling but is used for now since more accurate - // estimation of shared memory usage remains to be done, and the - // register spilling is not a functional concern. - // - // TODO: More accurate estimation of resource requirements - int64_t max_supported_size = max_threads_per_block_; - if (support_batching) { - auto available_shmem_capacity = - at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock / 2; - // TODO: don't assume it's always float. - auto element_size = sizeof(float); - max_supported_size = - static_cast(available_shmem_capacity / element_size); - } - - if (size_of_constrained_ids > max_supported_size) { + // The next important limit would be the register usage as we + // would not want to have excessive register spilling. The + // register usage would be linearly correlated with the batching + // factor. For now, just put a simple upper limit to avoid + // disastrous regressions. Fine tuning would be necessary. 
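+  // (Hypothetical example: 4096 constrained elements with 1024 threads and
+  // 8-byte work items gives ceilDiv(4096, 1024) * 8 / 4 = 8 estimated
+  // registers per thread, which is checked below together with the 20
+  // reserved registers against the per-thread budget.)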
+ const int64_t register_count_per_thread = + ceilDiv(size_of_constrained_ids, threads_per_block) * + bytes_per_element / 4; + const int64_t available_register_count_per_thread = + at::cuda::getCurrentDeviceProperties()->regsPerBlock / + threads_per_block; + // Make sure at least 20 registers are always available + const int64_t reserved_register_count_per_thread = 20; + if (register_count_per_thread + reserved_register_count_per_thread > + available_register_count_per_thread) { + reject( + "Expected register usage, ", + register_count_per_thread, + ", exceeds the available count, ", + available_register_count_per_thread); + } + + return size_of_constrained_ids; + } + + void checkSharedMemoryBufferUsage() { + // TODO: Use the constant and util functions added in #5272 + auto aligned_size = [](int64_t x) { return (x + 127) / 128 * 128; }; + + const int64_t cub_buffer_size = + aligned_size(cub_shmem_buffer_.getTotalSizeInBytes()); + + // TODO: Shared memory may be also used for resolving mismatched + // parallelization of constrained. + + const auto total_required_size = cub_buffer_size; + + const auto available_size = static_cast( + at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock); + + if (total_required_size > available_size) { reject( - "Extent of constrained logical IDs, ", - size_of_constrained_ids, - ", exceeds the maxinum supported size: ", - max_supported_size); + "Not enough shared memory. Required size for CUB: ", + cub_buffer_size, + ". Total required size: ", + total_required_size, + ". Available: ", + available_size); } } @@ -750,6 +830,7 @@ class RunTimeChecker : private IterVisitor { private: SchedulerRuntimeInfo& runtime_info_; int64_t max_threads_per_block_ = 0; + scheduler_tools::CubSharedMemoryBuffer cub_shmem_buffer_; bool can_schedule_ = true; std::string reject_reason_; diff --git a/csrc/scheduler/tools/cub_utils.cpp b/csrc/scheduler/tools/cub_utils.cpp new file mode 100644 index 00000000000..fd618cc9d3d --- /dev/null +++ b/csrc/scheduler/tools/cub_utils.cpp @@ -0,0 +1,375 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include +#include + +namespace nvfuser { +namespace scheduler_tools { + +namespace { + +constexpr int64_t alignUp(int64_t value, int64_t alignment) { + return (value + (alignment - 1)) & ~(alignment - 1); +} + +constexpr bool isPowerOfTwo(int64_t x) { + return x != 0 && (x & (x - 1)) == 0; +} + +constexpr int64_t ceilLog2(int64_t x) { + // returns ceil(log2(x)) for x >= 1 + int64_t n = 0; + int64_t v = 1; + while (v < x) { + v <<= 1; + ++n; + } + return n; +} + +/* + CUB BlockRadixSort shared memory usage. The comment and the code are + generated by Cursor with GPT-5. + + Assumptions: + - RADIX_BITS = 4 (RADIX_DIGITS = 16) + - SMEM bank size = 4 bytes + - INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS + - align32(x) rounds x up to the next multiple of 32 bytes: ((x + 31) & ~31) + + Composition: + - BlockRadixSort::TempStorage is a union of: + AscendingBlockRadixRank::TempStorage, + DescendingBlockRadixRank::TempStorage, + BlockExchange::TempStorage, + BlockExchange::TempStorage. + - Total bytes: S(BlockRadixSort) = max(S_rank, S_exch(KeyT), S_exch(ValueT)). + + Rank storage (independent of KeyT/ValueT): + - Let B = BLOCK_DIM_X x BLOCK_DIM_Y x BLOCK_DIM_Z, W = ceil(B / 32). + - PackedCounter = unsigned int (4 bytes) for 4-byte SMEM banks. 
+ - Aliasable part (digit counters or raking grid): S_rank_aliasable = 36 * B +bytes // RADIX_BITS=4 => PADDED_COUNTER_LANES=9 + - BlockScan contribution: + * If (B % 32 == 0) // warp-scans path + S_blockscan ~= align32(5 * W + 4) bytes + (W warp aggregates @4B, plus W bytes from array of empty per-warp +TempStorage, plus 4B block prefix) + * Else // raking fallback + S_blockscan is larger (~1 KB near B~200). Exact size depends on +BlockRakingLayout and alignment. + - Total: S_rank = S_rank_aliasable + S_blockscan. + - Final TempStorage size = align16(max(S_rank, S_exch(KeyT), S_exch(ValueT))) + because the union inherits 16-byte alignment from +BlockExchange::_TempStorage. + +Exchange storage (depends on KeyT/ValueT): + - TILE_ITEMS = B * ITEMS_PER_THREAD + - PADDING_ITEMS = (ITEMS_PER_THREAD > 4 && power_of_two(ITEMS_PER_THREAD)) ? +(TILE_ITEMS >> 5) : 0 // 32 banks +- S_exch(T) = align16(sizeof(T) * (TILE_ITEMS + PADDING_ITEMS)) // +BlockExchange::_TempStorage is alignas(16) + + Worked examples (KeyT = int64_t, ValueT = int64_t, BLOCK_DIM_Y = BLOCK_DIM_Z = +1, RADIX_BITS = 4, SMEM banks = 4B): + +ITEMS_PER_THREAD = 1 +- B = 128 (W = 4): S_rank = 36*128 + align32(5*4 + 4) = 4608 + 32 = 4640 + S_exch(key)=8*128=1024, values=1024 -> S = 4640 +- B = 224 (W = 7): S_rank = 36*224 + align32(5*7 + 4) = 8064 + 64 = 8128 + S = 8096 +- B = 256 (W = 8): S_rank = 36*256 + align32(5*8 + 4) = 9216 + 64 = 9280 + S = 9280 + - B = 200 (W = 7, raking): S_rank = 36*200 + 1024 = 8224 (exact) + S = 8224 + + ITEMS_PER_THREAD = 4 (no padding) + - B = 128: S = 4640 +- B = 224: S = 8128 + - B = 256: S = 9280 + - B = 200: S = 8224 + + ITEMS_PER_THREAD = 8 (padding applies; PADDING_ITEMS = (B*8)/32 = B/4) + - B = 128: S_rank = 4640; S_exch(T) = 8*(128*8 + 32) = 8448 -> S = 8448 + - B = 224: S_rank = 8128; S_exch(T) = 8*(224*8 + 56) = 14784 -> S = 14784 + - B = 256: S_rank = 9280; S_exch(T) = 8*(256*8 + 64) = 16896 -> S = 16896 + - B = 200: S_rank = 8224; S_exch(T) = 8*(200*8 + 50) = 13200 -> S = 13200 + + Rule of thumb: + - Rank aliasable scales as 36 * B; BlockScan adds a small aligned term if +B%32==0, else ~1 KB (varies with B). + - Exchange per type scales as sizeof(T) * B * ITEMS_PER_THREAD (padding kicks +in when ITEMS_PER_THREAD > 4 and pow2). + - For 64-bit key + 64-bit value, exchange overtakes rank around +ITEMS_PER_THREAD >= 5 (ignoring small/aligned term). + + sizeof() sanity check snippet: + using BRS = cub::BlockRadixSort; printf("%zu\n", sizeof(typename BRS::TempStorage)); +*/ +constexpr int64_t computeBlockRadixSortTempStorageBytes( + int64_t block_threads, + int64_t items_per_thread, + int64_t key_size_bytes, + int64_t value_size_bytes) { + // Rank aliasable part (RADIX_BITS=4, 4B banks → PADDED_COUNTER_LANES=9 → 36 + // bytes per thread) + const int64_t rank_aliasable_bytes = int64_t{36} * block_threads; + + // BlockScan contribution inside BlockRadixRank + const bool multiples_of_32 = (block_threads % int64_t{32}) == 0; + const int64_t warps = ceilDiv(block_threads, int64_t{32}); + + int64_t blockscan_bytes = 0; + if (multiples_of_32) { + // Warp-scans path: W warp aggregates (4B each) + W bytes for array of empty + // per-warp TempStorage + one block prefix (4B), aligned to 32 bytes + const int64_t raw = int64_t{5} * warps + int64_t{4}; + blockscan_bytes = alignUp(raw, int64_t{32}); + } else { + // Raking path: compute exact size using BlockRakingLayout and WarpScanSmem + // for PackedCounter (4B) + const int64_t max_raking_threads = + block_threads < int64_t{32} ? 
block_threads : int64_t{32}; + const int64_t segment_length = ceilDiv(block_threads, max_raking_threads); + const bool use_segment_padding = + ((segment_length & int64_t{1}) == 0) && (segment_length > int64_t{2}); + const int64_t raking_threads = ceilDiv(block_threads, segment_length); + + // WarpScan storage: shfl variant for power-of-two raking_threads (empty), + // smem variant otherwise (1.5 * raking_threads elements of 4B) + int64_t warp_scan_size = 0; + if (isPowerOfTwo(raking_threads)) { + // Empty TempStorage still contributes 1 byte as a member + warp_scan_size = 1; + } else { + const int64_t steps = ceilLog2(raking_threads); + const int64_t half_warp_threads = + (steps == 0) ? int64_t{0} : (int64_t{1} << (steps - 1)); + const int64_t warp_smem_elements = raking_threads + half_warp_threads; + warp_scan_size = + int64_t{4} * warp_smem_elements; // sizeof(PackedCounter)=4 + } + + // BlockRakingLayout grid + const int64_t grid_elements = raking_threads * + (segment_length + (use_segment_padding ? int64_t{1} : int64_t{0})); + const int64_t raking_grid_bytes = + int64_t{4} * grid_elements; // sizeof(PackedCounter)=4 + + // Layout with alignments: warp_scan (4B aligned) -> pad to 16B -> + // raking_grid (align 16) -> block_aggregate (4B) + int64_t total = 0; + total += warp_scan_size; + total = alignUp(total, int64_t{16}); + total += alignUp(raking_grid_bytes, int64_t{16}); + total += int64_t{4}; // block_aggregate of PackedCounter + total = alignUp(total, int64_t{16}); // struct alignment + + blockscan_bytes = total; + } + + const int64_t rank_bytes = rank_aliasable_bytes + blockscan_bytes; + + // Exchange storage for key/value types + const int64_t tile_items = block_threads * items_per_thread; + const bool needs_padding = + (items_per_thread > int64_t{4}) && isPowerOfTwo(items_per_thread); + const int64_t padding_items = needs_padding + ? (tile_items >> int64_t{5}) + : int64_t{0}; // 32 banks → LOG_SMEM_BANKS=5 + const int64_t exchange_keys = key_size_bytes * (tile_items + padding_items); + const int64_t exchange_values = + value_size_bytes * (tile_items + padding_items); + int64_t exchange_bytes = + exchange_keys > exchange_values ? exchange_keys : exchange_values; + // BlockExchange::_TempStorage is alignas(16), so size is rounded up to + // 16-byte multiple + exchange_bytes = alignUp(exchange_bytes, int64_t{16}); + + // Exchange buffers are in a union; contribution is the larger of the two. For + // keys-only, set value_size_bytes=0. + int64_t result = rank_bytes > exchange_bytes ? rank_bytes : exchange_bytes; + // The union holding rank/exchange storage inherits max alignment (16 via + // BlockExchange), so the final size rounds up to 16 bytes. + result = alignUp(result, int64_t{16}); + return result; +} + +/* + CUB BlockScan shared memory usage. The comment and the code are + generated by Cursor with GPT-5. 
+ + Assumptions: + - SMEM banks = 4 bytes; align32/align16 apply as below + - SAFE_ALGORITHM: if warp-scans is requested but B%32 != 0, fallback to raking + - B = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z, W = ceil(B/32) + + Warp-scans (B % 32 == 0): + - TempStorage layout (alignas(32)): + T warp_aggregates[W]; + WarpScan::TempStorage warp_scan[W]; // empty type -> contributes W + bytes total T block_prefix; + - Size: S = align32(W * (sizeof(T) + 1) + sizeof(T)) + + Raking (fallback or forced): + - Raking layout: + max_raking_threads = min(B, 32) + segment_length = ceil(B / max_raking_threads) + use_segment_padding = ((segment_length & 1) == 0) && (segment_length > 2) + raking_threads = ceil(B / segment_length) + - WarpScan::TempStorage: + * if raking_threads is power-of-two => SHFL path, empty -> contributes 1 + byte + * else SMEM path => elements = raking_threads + + 2^(ceil_log2(raking_threads)-1) bytes = elements * sizeof(T) + - BlockRakingLayout grid bytes: + GRID = raking_threads * (segment_length + (use_segment_padding ? 1 : 0)) + bytes = GRID * sizeof(T) + - TempStorage size (alignas(16)): + S = align16( warp_scan_bytes ) + + align16( grid_bytes ) + + sizeof(T) + => align16(total) + + Notes: + - For B=32 (warp-scans), W=1 and S = align32((sizeof(T)+1) + sizeof(T)) = 32 + for common T sizes (e.g., 4B, 8B). + - For B=16, SAFE fallback uses raking; raking_threads=16 (power-of-two) so + SHFL-only WarpScan contributes 1 byte. Typical sizes: T=4 => 96 bytes; T=8 => + 160 bytes. + + sizeof() sanity check snippet: + using BS = cub::BlockScan; + printf("%zu\n", sizeof(typename BS::TempStorage)); +*/ +constexpr int64_t computeBlockScanTempStorageBytes( + int64_t block_threads, + int64_t type_size_bytes, + bool use_warp_scans) { + const int64_t warps = ceilDiv(block_threads, int64_t{32}); + + // Warp-scans path + if (use_warp_scans && (block_threads % int64_t{32} == 0)) { + // alignas(32) struct with W warp aggregates (T), W empty TempStorages (1B + // each), and block_prefix (T) + const int64_t raw = + warps * (type_size_bytes + int64_t{1}) + type_size_bytes; + return alignUp(raw, int64_t{32}); + } + + // Raking path (BlockScanRaking) + const int64_t max_raking_threads = + block_threads < int64_t{32} ? block_threads : int64_t{32}; + const int64_t segment_length = ceilDiv(block_threads, max_raking_threads); + const bool use_segment_padding = + ((segment_length & int64_t{1}) == 0) && (segment_length > int64_t{2}); + const int64_t raking_threads = ceilDiv(block_threads, segment_length); + + // WarpScan storage for logical warp = raking_threads + int64_t warp_scan_size = 0; + if (isPowerOfTwo(raking_threads)) { + // Empty TempStorage still contributes 1 byte as a member + warp_scan_size = 1; + } else { + const int64_t steps = ceilLog2(raking_threads); + const int64_t half_warp_threads = + (steps == 0) ? int64_t{0} : (int64_t{1} << (steps - 1)); + const int64_t warp_smem_elements = raking_threads + half_warp_threads; + warp_scan_size = type_size_bytes * warp_smem_elements; + } + + // BlockRakingLayout grid + const int64_t grid_elements = raking_threads * + (segment_length + (use_segment_padding ? 
int64_t{1} : int64_t{0})); + const int64_t raking_grid_bytes = type_size_bytes * grid_elements; + + // Layout: warp_scan (4B aligned) -> pad to 16 -> raking_grid (align16) -> + // block_aggregate (T) -> pad to 16 + int64_t total = 0; + total += warp_scan_size; + total = alignUp(total, int64_t{16}); + total += alignUp(raking_grid_bytes, int64_t{16}); + total += type_size_bytes; + total = alignUp(total, int64_t{16}); + + return total; +} + +} // namespace + +void CubSharedMemoryBuffer::registerArgsort( + int64_t bdimx, + int64_t items_per_thread, + DataType dtype) { + max_bdimx_ = std::max(max_bdimx_, bdimx); + argsort_calls_.emplace(items_per_thread, dtype); +} + +void CubSharedMemoryBuffer::registerScan( + int64_t bdimx, + int64_t items_per_thread, + DataType dtype) { + max_bdimx_ = std::max(max_bdimx_, bdimx); + scan_calls_.emplace_back(dtype); +} + +void CubSharedMemoryBuffer::registerTopK( + int64_t bdimx, + int64_t items_per_thread, + DataType dtype) { + max_bdimx_ = std::max(max_bdimx_, bdimx); + topk_calls_.emplace(items_per_thread, dtype); +} + +int64_t CubSharedMemoryBuffer::getArgsortTotalSizeInBytes() const { + int64_t total_size = 0; + for (const auto& template_instance : argsort_calls_) { + total_size += computeBlockRadixSortTempStorageBytes( + max_bdimx_, + template_instance.items_per_thread, + dataTypeSizeByte(template_instance.dtype), + sizeof(int64_t)); + } + + return total_size; +} + +int64_t CubSharedMemoryBuffer::getTopKTotalSizeInBytes() const { + int64_t total_size = 0; + for (const auto& template_instance : topk_calls_) { + total_size += computeBlockRadixSortTempStorageBytes( + max_bdimx_, + template_instance.items_per_thread, + dataTypeSizeByte(template_instance.dtype), + sizeof(int64_t)); + } + + return total_size; +} + +int64_t CubSharedMemoryBuffer::getScanTotalSizeInBytes() const { + int64_t total_size = 0; + for (const auto& template_instance : scan_calls_) { + total_size += computeBlockScanTempStorageBytes( + max_bdimx_, + dataTypeSizeByte(template_instance.dtype), + /*use_warp_scans=*/true); + } + + return total_size; +} + +int64_t CubSharedMemoryBuffer::getTotalSizeInBytes() const { + return getArgsortTotalSizeInBytes() + getScanTotalSizeInBytes() + + getTopKTotalSizeInBytes(); +} + +} // namespace scheduler_tools +} // namespace nvfuser diff --git a/csrc/scheduler/tools/cub_utils.h b/csrc/scheduler/tools/cub_utils.h new file mode 100644 index 00000000000..a7f91e89641 --- /dev/null +++ b/csrc/scheduler/tools/cub_utils.h @@ -0,0 +1,95 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +namespace nvfuser { +namespace scheduler_tools { + +// Utility class to compute the size of the shared memory buffer used +// by CUB for operations like argsort +class CubSharedMemoryBuffer { + public: + void registerArgsort(int64_t bdimx, int64_t items_per_thread, DataType dtype); + + void registerScan(int64_t bdimx, int64_t items_per_thread, DataType dtype); + + void registerTopK(int64_t bdimx, int64_t items_per_thread, DataType dtype); + + int64_t getTotalSizeInBytes() const; + + int64_t getArgsortTotalSizeInBytes() const; + + int64_t getScanTotalSizeInBytes() const; + + int64_t getTopKTotalSizeInBytes() const; + + private: + // Parameters affecting the buffer size of each call using block + // radix sort. bdimx is common across all calls, so not included + // here. 
+  struct BlockRadixSortParameters {
+    int64_t items_per_thread;
+    DataType dtype;
+
+    bool operator==(const BlockRadixSortParameters& other) const {
+      return items_per_thread == other.items_per_thread && dtype == other.dtype;
+    }
+  };
+
+  struct BlockRadixSortParametersHash {
+    std::size_t operator()(const BlockRadixSortParameters& key) const {
+      return std::hash<int64_t>()(key.items_per_thread);
+    }
+  };
+
+  // Parameters affecting the buffer size of each call using block
+  // scan. bdimx is common across all calls, so not included
+  // here.
+  struct BlockScanParameters {
+    DataType dtype;
+
+    bool operator==(const BlockScanParameters& other) const {
+      return dtype == other.dtype;
+    }
+  };
+
+  struct BlockScanParametersHash {
+    std::size_t operator()(const BlockScanParameters& key) const {
+      if (auto prim_type = std::get_if<PrimDataType>(&key.dtype.type)) {
+        return static_cast<std::size_t>(*prim_type) + 1;
+      } else {
+        return 0;
+      }
+    }
+  };
+
+ private:
+  int64_t max_bdimx_ = -1;
+
+  // Keep track of argsort calls to compute the total size of the
+  // shared memory buffers used for argsort.
+  std::unordered_set<BlockRadixSortParameters, BlockRadixSortParametersHash>
+      argsort_calls_;
+
+  // Keep track of scan calls to compute the total size of the
+  // shared memory buffers used for scan. Note that each call appears
+  // to produce a distinct template instantiation due to the lambda
+  // parameter, and thus there's no reuse even for the same data
+  // type. This should be fixed by using dynamically allocated buffers.
+  std::vector<BlockScanParameters> scan_calls_;
+
+  // Keep track of topk calls to compute the total size of the
+  // shared memory buffers used for topk.
+  std::unordered_set<BlockRadixSortParameters, BlockRadixSortParametersHash>
+      topk_calls_;
+};
+
+} // namespace scheduler_tools
+} // namespace nvfuser
diff --git a/doc/dev/host_ir_jit.md b/doc/dev/host_ir_jit.md
index 4a10292f333..833536342cd 100644
--- a/doc/dev/host_ir_jit.md
+++ b/doc/dev/host_ir_jit.md
@@ -98,18 +98,21 @@ KernelArgumentHolder HostIrJitImpl::runWithInputs(const KernelArgumentHolder& ar
 ```
 *Detailed Implementation:* https://github.com/NVIDIA/Fuser/blob/3ac1a4697b6b5c31e4dbb9763b3b6db2f0e0164b/csrc/host_ir/jit.cpp#L1399-L1453
-## Configuration and Runtime Options
-
-### Build Requirements
-**LLVM 18.1+ is required** to build nvFuser. You can switch between Host IR JIT and Host IR Evaluator at runtime.
-
-### Runtime Configuration
-You can enable Host IR JIT via runtime option `EnableOption::HostIrJit` or environment `NVFUSER_ENABLE="host_ir_jit"`.
-
-When `host_ir_jit` is enabled, the runtime uses LLVM ORC JIT for low-latency host execution. When disabled, it falls back to the Host IR Evaluator.
-
+## Configuration and Build Options
+Building nvFuser with `NVFUSER_BUILD_HOST_IR_JIT=1` enables Host IR JIT as the default runtime in the Host IR execution path.
+Otherwise, the default runtime is the Host IR Evaluator. In the future, once LLVM is fully supported on all build machines, we can
+drop this opt-in flag and instead use an `EnableOption` to control backend switching after the build is done.
+
+Sample build:
+```bash
+NVFUSER_BUILD_HOST_IR_JIT=1 pip install --no-build-isolation -e python -v
+```
+or
+```bash
+NVFUSER_BUILD_HOST_IR_JIT=1 _bn
+```
 ## Future Integration plan
-We plan to turn on host IR JIT by default after its function and performance are on par.
Known missing supports and bugs are: **Ops need to be supported:** diff --git a/python/utils.py b/python/utils.py index 98cb219f052..b1caa9618dc 100644 --- a/python/utils.py +++ b/python/utils.py @@ -25,6 +25,7 @@ class BuildConfig: build_with_asan: bool = False build_without_distributed: bool = False explicit_error_check: bool = False + build_with_host_ir_jit: bool = False overwrite_version: bool = False version_tag: str = None build_type: str = "Release" @@ -104,6 +105,12 @@ def parse_args(): action="store_true", help="Build nvfuser with UCC support", ) + parser.add_argument( + "--build-with-host-ir-jit", + dest="build_with_host_ir_jit", + action="store_true", + help="Build nvfuser with Host IR JIT support", + ) parser.add_argument( "--explicit-error-check", dest="explicit_error_check", @@ -206,6 +213,7 @@ def create_build_config(): no_benchmark=args.no_benchmark, no_ninja=args.no_ninja, build_with_ucc=args.build_with_ucc, + build_with_host_ir_jit=args.build_with_host_ir_jit, build_with_asan=args.build_with_asan, build_without_distributed=args.build_without_distributed, explicit_error_check=args.explicit_error_check, @@ -251,6 +259,8 @@ def override_build_config_from_env(config): config.no_ninja = get_env_flag_bool("NVFUSER_BUILD_NO_NINJA") if "NVFUSER_BUILD_WITH_UCC" in os.environ: config.build_with_ucc = get_env_flag_bool("NVFUSER_BUILD_WITH_UCC") + if "NVFUSER_BUILD_HOST_IR_JIT" in os.environ: + config.build_with_host_ir_jit = get_env_flag_bool("NVFUSER_BUILD_HOST_IR_JIT") if "NVFUSER_BUILD_WITH_ASAN" in os.environ: config.build_with_asan = get_env_flag_bool("NVFUSER_BUILD_WITH_ASAN") if "NVFUSER_BUILD_WITHOUT_DISTRIBUTED" in os.environ: @@ -481,6 +491,7 @@ def on_or_off(flag: bool) -> str: f"-DPython_EXECUTABLE={sys.executable}", f"-DBUILD_NVFUSER_BENCHMARK={on_or_off(not config.no_benchmark)}", f"-DNVFUSER_DISTRIBUTED={on_or_off(not config.build_without_distributed)}", + f"-DUSE_HOST_IR_JIT={on_or_off(config.build_with_host_ir_jit)}", f"-DCUTLASS_MAX_JOBS={config.cutlass_max_jobs}", f"-DNVMMH_INCLUDE_DIR={config.nvmmh_include_dir}", "-B", diff --git a/runtime/scan.cu b/runtime/scan.cu index 5fd3b03bc83..246e167fc13 100644 --- a/runtime/scan.cu +++ b/runtime/scan.cu @@ -84,11 +84,14 @@ __device__ void blockScan( // CUB BlockScan setup - with proper multi-dimensional block support // CUB BlockScan template parameters are simpler than BlockRadixSort: - // - Key type, Block dimensions, Items per thread, Algorithm (optional) + // - Key type, Block dimensions, Items per thread, Algorithm + // (optional). BLOCK_SCAN_WARP_SCANS is not the default option but + // is chosen here for now as it may work well for inference + // workloads. More detailed evaluation should be done. 
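+  // Note: the scheduler-side shared memory estimate in
+  // scheduler/tools/cub_utils.cpp (computeBlockScanTempStorageBytes) assumes
+  // this same algorithm (use_warp_scans=true), so the two should be kept in
+  // sync.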
using BlockScan = cub::BlockScan< typename cub_utils::CudaType::type, // Data type BLOCK_DIM_X, // X dimension - cub::BLOCK_SCAN_RAKING, // Algorithm (default for BlockScan) + cub::BLOCK_SCAN_WARP_SCANS, // Algorithm BLOCK_DIM_Y, // Y dimension BLOCK_DIM_Z // Z dimension >; diff --git a/tests/cpp/test_argsort.cpp b/tests/cpp/test_argsort.cpp index e8d6ff192e1..ab9b045bbdb 100644 --- a/tests/cpp/test_argsort.cpp +++ b/tests/cpp/test_argsort.cpp @@ -13,7 +13,10 @@ #include #include #include +#include +#include #include +#include #include #include @@ -305,4 +308,124 @@ TEST_F(ArgsortTest, BufferSync) { testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__); } +class ArgsortParameterizedWithBlockAndBatch + : public ArgsortTest, + public ::testing::WithParamInterface> {}; + +TEST_P(ArgsortParameterizedWithBlockAndBatch, SharedMemoryRequirement) { + DisableOptionsGuard disable_options_guard; + // Avoid using magic zero to make the estimation simpler + DisableOptionsGuard::getCurOptions().set(DisableOption::MagicZero); + // Avoid insertion of segmenter_set + preseg_passes::OptimizationPassGuard + optimization_guard(false); + + const auto [size, batch, has_duplicate, has_extra] = GetParam(); + + // This combination is not considered as the number of threads + // exceeds the limit + if (ceilDiv(size, batch) > 1024) { + return; + } + + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + DataType dtype = DataType::Int; + DataType dtype_extra = DataType::Float; + + std::vector shape = {size}; + + auto tv0 = makeContigConcreteTensor(shape, dtype); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = argsort(tv1, 0); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + // Duplicate the above call but should not change the usage as it's + // the same template instantiation + if (has_duplicate) { + auto tv4 = set(tv0); + auto tv5 = argsort(tv4, 0); + auto tv6 = set(tv5); + fusion.addOutput(tv6); + } + + // Create a different instantiation + if (has_extra) { + auto tv7 = castOp(dtype_extra, tv0); + auto tv8 = argsort(tv7, 0); + auto tv9 = set(tv8); + fusion.addOutput(tv9); + } + + for (auto tv : fusion.allTvs()) { + if (batch > 1) { + tv->split(-1, batch); + if (tv->isDefinitionType()) { + tv->axis(-1)->parallelize(ParallelType::Group); + } + } + tv->axis(0)->parallelize(ParallelType::TIDx); + } + + auto options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); + at::Tensor t0 = at::randint(0, shape[0], shape, options); + + scheduler_tools::CubSharedMemoryBuffer smem_buffer; + smem_buffer.registerArgsort(ceilDiv(size, batch), batch, dtype); + // The duplicate should not increase the buffer usage + if (has_extra) { + smem_buffer.registerArgsort(ceilDiv(size, batch), batch, dtype_extra); + } + const int64_t expected_size = smem_buffer.getTotalSizeInBytes(); + + const int64_t available_capacity = + at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock; + const int64_t opt_in_available_capacity = + at::cuda::getCurrentDeviceProperties()->sharedMemPerBlockOptin; + + KernelExecutor ke; + if (expected_size <= available_capacity) { + ke.compile(&fusion, {t0}); + auto outputs = ke.run({t0}); + testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__); + + // The test would fail if the estimate is not 100% accurate. That + // may be too strict and fragile as a test. After all, we would + // just need a reasonably tight upper bound. Consider relaxing the + // condition if necessary. 
+ EXPECT_EQ(expected_size, ke.getStaticSmemSize()) + << "Actual static shared memory size was different"; + } else if (expected_size > opt_in_available_capacity) { + // Compilation should fail + EXPECT_THAT( + [&]() { ke.compile(&fusion, {t0}); }, + testing::Throws()); + } else { + // It doesn't seem consistent whether compilation or launch should + // fail if the requirement of static shared memory exceeds the default + // limit but within the opt-in larger limit. As we should move to + // dynamic allocations anyway, don't assert for now. + } +}; + +INSTANTIATE_TEST_SUITE_P( + , + ArgsortParameterizedWithBlockAndBatch, + testing::Combine( + testing::Values(128, 512, 1024, 2048, 4096), + testing::Values(1, 2, 3, 8), + testing::Bool(), + testing::Bool()), + [](const auto& info) { + std::ostringstream os; + os << std::get<0>(info.param) << "_" << std::get<1>(info.param) << "_" + << std::get<2>(info.param) << "_" << std::get<3>(info.param); + return os.str(); + }); + } // namespace nvfuser diff --git a/tests/cpp/test_greedy.cpp b/tests/cpp/test_greedy.cpp index cdcbedb582d..13a20e43480 100644 --- a/tests/cpp/test_greedy.cpp +++ b/tests/cpp/test_greedy.cpp @@ -12,6 +12,8 @@ #include #include #include +#include +#include #include #include #include @@ -912,4 +914,201 @@ INSTANTIATE_TEST_SUITE_P( return os.str(); }); +class GreedySchedulerTestShmemSize : public GreedySchedulerTest, + public ::testing::WithParamInterface { +}; + +// Simplified version of +// ArgsortParameterizedWithBlockandBatch.SharedMemoryRequirement. The +// test may be segmented but should not fail as long as the +// expectation of the shared memory usage is accurate. +TEST_P(GreedySchedulerTestShmemSize, Argsort) { + DisableOptionsGuard disable_options_guard; + DisableOptionsGuard::getCurOptions().set(DisableOption::MagicZero); + preseg_passes::OptimizationPassGuard + optimization_guard(false); + + const auto size = GetParam(); + + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + DataType dtype = DataType::Int; + DataType dtype_extra = DataType::Float; + + std::vector shape = {size}; + + auto tv0 = makeContigConcreteTensor(shape, dtype); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = argsort(tv1, 0); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + // Duplicate the above call but should not change the usage as it's + // the same template instantiation + auto tv4 = set(tv0); + auto tv5 = argsort(tv4, 0); + auto tv6 = set(tv5); + fusion.addOutput(tv6); + + // Create a different instantiation + auto tv7 = castOp(dtype_extra, tv0); + auto tv8 = argsort(tv7, 0); + auto tv9 = set(tv8); + fusion.addOutput(tv9); + + auto options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); + at::Tensor t0 = at::randint(0, shape[0], shape, options); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({t0}); + testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__); +} + +// Simplified version of +// TopKParameterizedWithBlockandBatch.SharedMemoryRequirement. The +// test may be segmented but should not fail as long as the +// expectation of the shared memory usage is accurate. 
+TEST_P(GreedySchedulerTestShmemSize, TopK) { + DisableOptionsGuard disable_options_guard; + DisableOptionsGuard::getCurOptions().set(DisableOption::MagicZero); + preseg_passes::OptimizationPassGuard + optimization_guard(false); + + const auto size = GetParam(); + + // topk doesn't support batching, so the maximum is 1024 + if (size > 1024) { + return; + } + + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + DataType dtype = DataType::Int; + DataType dtype_extra = DataType::Float; + + std::vector shape = {size}; + + auto tv0 = makeContigConcreteTensor(shape, dtype); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = topk(tv1, fusion.oneVal(DataType::Int), 0).values; + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + // Duplicate the above call but should not change the usage as it's + // the same template instantiation + auto tv4 = set(tv0); + auto tv5 = topk(tv4, fusion.oneVal(DataType::Int), 0).values; + auto tv6 = set(tv5); + fusion.addOutput(tv6); + + // Create a different instantiation + auto tv7 = castOp(dtype_extra, tv0); + auto tv8 = topk(tv7, fusion.oneVal(DataType::Int), 0).values; + auto tv9 = set(tv8); + fusion.addOutput(tv9); + + auto options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); + at::Tensor t0 = at::randint(0, shape[0], shape, options); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({t0}); + testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__); +} + +// Simplified version of +// ScanParameterizedWithBlockandBatch.SharedMemoryRequirement. The +// test may be segmented but should not fail as long as the +// expectation of the shared memory usage is accurate. 
+TEST_P(GreedySchedulerTestShmemSize, Scan) { + DisableOptionsGuard disable_options_guard; + DisableOptionsGuard::getCurOptions().set(DisableOption::MagicZero); + preseg_passes::OptimizationPassGuard + optimization_guard(false); + + const auto size = GetParam(); + + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + DataType dtype = DataType::Int; + DataType dtype_extra = DataType::Float; + + std::vector shape = {size}; + + auto tv0 = makeContigConcreteTensor(shape, dtype); + fusion.addInput(tv0); + + auto tv1 = cumsum(tv0, 0); + fusion.addOutput(tv1); + + // Duplicate the above call but should not change the usage as it's + // the same template instantiation + auto tv2 = cumsum(tv0, 0); + fusion.addOutput(tv2); + + // Create a different instantiation + auto tv3 = castOp(dtype_extra, tv0); + auto tv4 = cumsum(tv3, 0); + fusion.addOutput(tv4); + + auto options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); + at::Tensor t0 = at::randint(0, shape[0], shape, options); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({t0}); + testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__); +} + +INSTANTIATE_TEST_SUITE_P( + , + GreedySchedulerTestShmemSize, + testing::Values(128, 256, 512, 1024, 2048, 4096), + [](const auto& info) { + std::ostringstream os; + os << info.param; + return os.str(); + }); + +TEST_F(GreedySchedulerTest, TMP) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = set(tv1); + auto tv3 = set(tv2); + auto tv4 = set(tv3); + fusion.addOutput(tv4); + + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(0)->parallelize(ParallelType::TIDx); + tv2->setMemoryType(MemoryType::Shared); + tv2->axis(0)->parallelize(ParallelType::TIDx); + tv3->setMemoryType(MemoryType::Shared); + tv3->axis(0)->parallelize(ParallelType::TIDx); + tv4->axis(0)->parallelize(ParallelType::TIDx); + + fusion.printKernel(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({100}, options); + + KernelExecutor ke; + ke.compile(&fusion, {t0}); + auto outputs = ke.run({t0}); + testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__); +} + } // namespace nvfuser diff --git a/tests/cpp/test_host_ir_integration.cpp b/tests/cpp/test_host_ir_integration.cpp index 01614755adb..b578cd17d8a 100644 --- a/tests/cpp/test_host_ir_integration.cpp +++ b/tests/cpp/test_host_ir_integration.cpp @@ -29,7 +29,6 @@ class HostIrIntegrationTest : public NVFuserTest { protected: HostIrIntegrationTest() { EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering); - EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrJit); } }; diff --git a/tests/cpp/test_host_ir_jit.cpp b/tests/cpp/test_host_ir_jit.cpp index 9414e07a051..a3b5f13d619 100644 --- a/tests/cpp/test_host_ir_jit.cpp +++ b/tests/cpp/test_host_ir_jit.cpp @@ -20,12 +20,7 @@ namespace nvfuser { namespace hir { -class HostIrJitTest : public NVFuserTest { - protected: - HostIrJitTest() { - EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrJit); - } -}; +using HostIrJitTest = NVFuserTest; // Build with: python setup.py install --build-with-host-ir-jit TEST_F(HostIrJitTest, Set) { auto hic = std::make_unique(); @@ -338,7 +333,7 @@ TEST_F(HostIrJitTest, Matmul) { HostIrJit jit(std::move(hic)); - auto options = 
at::TensorOptions().device(at::kCUDA, 0).dtype(at::kFloat); + auto options = at::TensorOptions().device(at::kCUDA, 0).dtype(torch::kFloat); at::Tensor t0 = at::randn({H, M, K}, options); at::Tensor t1 = at::randn({H, K, N}, options); at::Tensor t2 = at::randn({H, M, N}, options); @@ -382,7 +377,7 @@ TEST_F(HostIrJitTest, MatmulOut) { HostIrJit jit(std::move(hic)); - auto options = at::TensorOptions().device(at::kCUDA, 0).dtype(at::kFloat); + auto options = at::TensorOptions().device(at::kCUDA, 0).dtype(torch::kFloat); at::Tensor t0 = at::randn({H, M, K}, options); at::Tensor t1 = at::randn({H, K, N}, options); std::unordered_map concrete_input_buffers = { @@ -433,7 +428,7 @@ TEST_F(HostIrJitTest, Linear) { HostIrJit jit(std::move(hic)); - auto options = at::TensorOptions().device(at::kCUDA, 0).dtype(at::kFloat); + auto options = at::TensorOptions().device(at::kCUDA, 0).dtype(torch::kFloat); auto in_at = at::randint(5, {B, M, K}, options); auto weight_at = at::randint(5, {N, K}, options); auto bias_at = at::randint(5, {N}, options); diff --git a/tests/cpp/test_scan.cpp b/tests/cpp/test_scan.cpp index e41a392af94..5715ae4fd24 100644 --- a/tests/cpp/test_scan.cpp +++ b/tests/cpp/test_scan.cpp @@ -10,8 +10,11 @@ #include #include #include +#include +#include #include #include +#include #include #include @@ -661,4 +664,136 @@ TEST_F(ScanTest, LowPrecision) { testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__); } +class ScanParameterizedWithBlock + : public ScanTest, + public ::testing::WithParamInterface> {}; + +TEST_P(ScanParameterizedWithBlock, SharedMemoryRequirement) { + DisableOptionsGuard disable_options_guard; + // Avoid using magic zero to make the estimation simpler + DisableOptionsGuard::getCurOptions().set(DisableOption::MagicZero); + // Avoid insertion of segmenter_set + preseg_passes::OptimizationPassGuard + optimization_guard(false); + + const auto [size, batch, has_duplicate, has_extra] = GetParam(); + + // This combination is not considered as the number of threads + // exceeds the limit + if (ceilDiv(size, batch) > 1024) { + return; + } + + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + DataType dtype = DataType::Int; + DataType dtype_extra = DataType::Float; + + std::vector shape = {size}; + + auto tv0 = makeContigConcreteTensor(shape, dtype); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = cumsum(tv1, 0); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + // Unlike ArgsortOp, scan passes a lambda to the CUB template + // function, so each invocation seems to be treated as a unique + // instantiation and doubles the memory usage. This should not be an + // issue once shared memory reuse is implemented. 
+ // the same template instantiation + if (has_duplicate) { + auto tv4 = set(tv0); + auto tv5 = cumsum(tv4, 0); + auto tv6 = set(tv5); + fusion.addOutput(tv6); + } + + // Create a different instantiation + if (has_extra) { + auto tv7 = castOp(dtype_extra, tv0); + auto tv8 = cumsum(tv7, 0); + auto tv9 = set(tv8); + fusion.addOutput(tv9); + } + + for (auto tv : fusion.allTvs()) { + if (batch > 1) { + tv->split(-1, batch); + if (tv->isDefinitionType()) { + tv->axis(-1)->parallelize(ParallelType::Group); + } + } + tv->axis(0)->parallelize(ParallelType::TIDx); + } + + auto options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); + at::Tensor t0 = at::randint(0, shape[0], shape, options); + + scheduler_tools::CubSharedMemoryBuffer smem_buffer; + smem_buffer.registerScan(ceilDiv(size, batch), batch, dtype); + if (has_duplicate) { + smem_buffer.registerScan(ceilDiv(size, batch), batch, dtype); + } + if (has_extra) { + smem_buffer.registerScan(ceilDiv(size, batch), batch, dtype_extra); + } + const int64_t expected_size = smem_buffer.getTotalSizeInBytes(); + + const int64_t available_capacity = + at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock; + const int64_t opt_in_available_capacity = + at::cuda::getCurrentDeviceProperties()->sharedMemPerBlockOptin; + + KernelExecutor ke; + if (expected_size <= available_capacity) { + ke.compile(&fusion, {t0}); + auto outputs = ke.run({t0}); + testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__); + // Not sure why but when the block size is smaller than a warp, + // the actual size is even smaller than the estimation. + if (ceilDiv(size, batch) < 32) { + EXPECT_LE(ke.getStaticSmemSize(), expected_size) + << "Actual static shared memory size was not smaller than the " + "expectation"; + } else { + // The test would fail if the estimate is not 100% accurate. That + // may be too strict and fragile as a test. After all, we would + // just need a reasonably tight upper bound. Consider relaxing the + // condition if necessary. + EXPECT_EQ(expected_size, ke.getStaticSmemSize()) + << "Actual static shared memory size was different"; + } + } else if (expected_size > opt_in_available_capacity) { + // Compilation should fail + EXPECT_THAT( + [&]() { ke.compile(&fusion, {t0}); }, + testing::Throws()); + } else { + // It doesn't seem consistent whether compilation or launch should + // fail if the requirement of static shared memory exceeds the default + // limit but within the opt-in larger limit. As we should move to + // dynamic allocaitons anyway, don't assert for now. 
+ } +}; + +INSTANTIATE_TEST_SUITE_P( + , + ScanParameterizedWithBlock, + testing::Combine( + testing::Values(128, 512, 1024, 2048, 4096), + testing::Values(1, 2, 3, 8), + testing::Bool(), + testing::Bool()), + [](const auto& info) { + std::ostringstream os; + os << std::get<0>(info.param) << "_" << std::get<1>(info.param) << "_" + << std::get<2>(info.param) << "_" << std::get<3>(info.param); + return os.str(); + }); + } // namespace nvfuser diff --git a/tests/cpp/test_topk.cpp b/tests/cpp/test_topk.cpp index 8aba3e78ee0..a60a3bed445 100644 --- a/tests/cpp/test_topk.cpp +++ b/tests/cpp/test_topk.cpp @@ -13,8 +13,11 @@ #include #include #include +#include +#include #include #include +#include #include #include #include @@ -683,4 +686,119 @@ TEST_F(TopKTest, BufferSync) { testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__); } +class TopKParameterizedWithBlockandBatch + : public TopKTest, + public ::testing::WithParamInterface> {}; + +TEST_P(TopKParameterizedWithBlockandBatch, SharedMemoryRequirement) { + DisableOptionsGuard disable_options_guard; + // Avoid using magic zero to make the estimation simpler + DisableOptionsGuard::getCurOptions().set(DisableOption::MagicZero); + // Avoid insertion of segmenter_set + preseg_passes::OptimizationPassGuard + optimization_guard(false); + + const auto [size, batch, has_dulicate, has_extra] = GetParam(); + + ASSERT_EQ(batch, 1) << "TopKOp does not support batching yet"; + + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + DataType dtype = DataType::Int; + DataType dtype_extra = DataType::Float; + + std::vector shape = {size}; + + auto tv0 = makeContigConcreteTensor(shape, dtype); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = topk(tv1, fusion.oneVal(DataType::Int), 0).values; + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + // Duplicate the above call but should not change the usage as it's + // the same template instantiation + if (has_dulicate) { + auto tv4 = set(tv0); + auto tv5 = topk(tv4, fusion.oneVal(DataType::Int), 0).values; + auto tv6 = set(tv5); + fusion.addOutput(tv6); + } + + // Create a different instantiation + if (has_extra) { + auto tv7 = castOp(dtype_extra, tv0); + auto tv8 = topk(tv7, fusion.oneVal(DataType::Int), 0).values; + auto tv9 = set(tv8); + fusion.addOutput(tv9); + } + + for (auto tv : fusion.allTvs()) { + if (batch > 1) { + tv->split(-1, batch); + if (tv->isDefinitionType()) { + tv->axis(-1)->parallelize(ParallelType::Group); + } + } + tv->axis(0)->parallelize(ParallelType::TIDx); + } + + auto options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); + at::Tensor t0 = at::randint(0, shape[0], shape, options); + + scheduler_tools::CubSharedMemoryBuffer smem_buffer; + smem_buffer.registerTopK(ceilDiv(size, batch), batch, dtype); + // The duplicate should not increase the buffer usage + if (has_extra) { + smem_buffer.registerTopK(ceilDiv(size, batch), batch, dtype_extra); + } + const int64_t expected_size = smem_buffer.getTotalSizeInBytes(); + + const int64_t available_capacity = + at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock; + const int64_t opt_in_available_capacity = + at::cuda::getCurrentDeviceProperties()->sharedMemPerBlockOptin; + + KernelExecutor ke; + if (expected_size <= available_capacity) { + ke.compile(&fusion, {t0}); + auto outputs = ke.run({t0}); + testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__); + // The test would fail if the estimate is not 100% accurate. 
That + // may be too strict and fragile as a test. After all, we would + // just need a reasonably tight upper bound. Consider relaxing the + // condition if necessary. + EXPECT_EQ(expected_size, ke.getStaticSmemSize()) + << "Actual static shared memory size was different"; + } else if (expected_size > opt_in_available_capacity) { + // Compilation should fail + EXPECT_THAT( + [&]() { ke.compile(&fusion, {t0}); }, + testing::Throws()); + } else { + // It doesn't seem consistent whether compilation or launch should + // fail if the requirement of static shared memory exceeds the default + // limit but within the opt-in larger limit. As we should move to + // dynamic allocaitons anyway, don't assert for now. + } +}; + +INSTANTIATE_TEST_SUITE_P( + , + TopKParameterizedWithBlockandBatch, + testing::Combine( + testing::Values(128, 256, 512, 1024), + testing::Values(1), + testing::Bool(), + testing::Bool()), + [](const auto& info) { + std::ostringstream os; + os << std::get<0>(info.param) << "_" << std::get<1>(info.param) << "_" + << std::get<2>(info.param) << "_" << std::get<3>(info.param); + return os.str(); + }); + } // namespace nvfuser
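
A minimal standalone sanity-check sketch for the CUB shared memory estimates above. It assumes CUB is available and the file (hypothetically sanity_check.cu) is compiled with nvcc; the 64-byte and 4640-byte figures are the analytic values from the worked formulas and examples in cub_utils.cpp, not guaranteed constants of every CUB version.

// sanity_check.cu: print CUB's actual TempStorage sizes next to the analytic
// estimates used by CubSharedMemoryBuffer.
#include <cstdio>
#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_scan.cuh>

int main() {
  // BlockScan, warp-scans path, BLOCK_DIM_X = 128, T = int64_t:
  // W = 128 / 32 = 4 warps, estimate = align32(4 * (8 + 1) + 8) = 64 bytes.
  using BlockScan =
      cub::BlockScan<long long, 128, cub::BLOCK_SCAN_WARP_SCANS, 1, 1>;
  std::printf(
      "BlockScan: estimate=64 actual=%zu\n",
      sizeof(typename BlockScan::TempStorage));

  // BlockRadixSort, 64-bit key and value, ITEMS_PER_THREAD = 1, B = 128:
  // the worked example above gives S_rank = 4608 + 32 = 4640 bytes.
  using BlockRadixSort = cub::BlockRadixSort<long long, 128, 1, long long>;
  std::printf(
      "BlockRadixSort: estimate=4640 actual=%zu\n",
      sizeof(typename BlockRadixSort::TempStorage));
  return 0;
}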