Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,8 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/scheduler/nvmmh.cpp
${NVFUSER_SRCS_DIR}/scheduler/communication.cpp
${NVFUSER_SRCS_DIR}/scheduler/normalization_inner.cpp
${NVFUSER_SRCS_DIR}/scheduler/normalization_inner_non_tma.cpp
${NVFUSER_SRCS_DIR}/scheduler/normalization_inner_tma.cpp
${NVFUSER_SRCS_DIR}/scheduler/normalization_inner_outer.cpp
${NVFUSER_SRCS_DIR}/scheduler/normalization_inner_outer_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/normalization_inner_outer_tma_ws.cpp
Expand Down
6 changes: 5 additions & 1 deletion csrc/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <ops/arith.h>
#include <options.h>
#include <scheduler/debug_utils.h>
#include <scheduler/normalization_inner_tma.h>
#include <scheduler/normalization_utils.h>
#include <transform_iter.h>
#include <transform_replay.h>
Expand Down Expand Up @@ -414,6 +415,7 @@ std::unique_ptr<SegmentedFusion> SegmentedFusion::fromCompleteFusion(
};
SegmentCandidateFinderOptions scfo;
if (scfo.run_translate_welford && isPersistentScheduler()) {
std::cout << "translateWelfordInFusion" << std::endl;
SegmentCandidateFinder::translateWelfordInFusion(fusion, runtime_inputs);
}

Expand Down Expand Up @@ -2834,8 +2836,10 @@ bool TranslateApplicableWelford::isValidPersistentFusion(
// However, when it comes to cross grid reduction, the additional grid
// synchronization carries substantial overhead and does not yield any
// performance gains.
return heuristic_params->as<ReductionParams>()->persistent_kernel &&
bool is_non_tma_persistent =
heuristic_params->as<ReductionParams>()->persistent_kernel &&
!heuristic_params->as<ReductionParams>()->cross_grid_outer_reduction;
return is_non_tma_persistent || heuristic_params->isA<InnerNormTmaParams>();
}

// Note that when segmented it is assumed that insertion of lower
Expand Down
1 change: 1 addition & 0 deletions csrc/options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ const std::unordered_map<std::string, EnableOption>& getEnableOptions() {
{"wait_debugger", EnableOption::WaitDebugger},
{"warn_register_spill", EnableOption::WarnRegisterSpill},
{"tma_pointwise", EnableOption::TmaPointwise},
{"tma_inner_persistent", EnableOption::TmaInnerPersistent},
{"ws_normalization", EnableOption::WarpSpecializedNormalization},
{"host_ir_lowering", EnableOption::HostIrLowering},
{"host_ir_jit", EnableOption::HostIrJit},
Expand Down
1 change: 1 addition & 0 deletions csrc/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ enum class EnableOption {
// will wait for `gdb attach` at the start.
WarnRegisterSpill, //! Enable warnings of register spill
TmaPointwise, //! Enable TMA pointwise kernel
TmaInnerPersistent, //! Enable TMA inner persistent kernel
WarpSpecializedNormalization, //! Enable warp specialized persistent kernel
HostIrLowering, //! Enable FusionKernelRuntime lowering to host IR
HostIrJit, //! Enable Host IR JIT compilation with LLVM
Expand Down
3 changes: 3 additions & 0 deletions csrc/parallel_dimension_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ void ParallelDimensionMap::adjustMappingsForWarpSpecialization() {
"divisible with 128 threads.");
int64_t ws_num_threads_pad = 128 / other_active_pts_threads;
int64_t after_pad = getThreadCountInDim(ws_pt) + ws_num_threads_pad;
std::cout << "after_pad: " << after_pad
<< ", other_active_pts_threads: " << other_active_pts_threads
<< ", ws_num_threads_pad: " << ws_num_threads_pad << std::endl;
NVF_ERROR(
(after_pad * other_active_pts_threads) % 128 == 0,
"Illegal register sharing on ",
Expand Down
3 changes: 2 additions & 1 deletion csrc/predicate_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -769,9 +769,10 @@ OneDimTmaPredicateInfo PredicateCompute::OneDimTmaLoadExpectArrive(
// domain is divisible.
replace_map[fl->index()] = GpuLower::current()->kernel()->zeroVal();
auto id_def = fl->iter_domain()->definition();
if (!id_def) {
if (!id_def || fl->iter_domain()->isBlockDim()) {
continue;
}

if (auto split = dynamic_cast<Split*>(id_def)) {
GpuLower::current()->validate(
split->isDivisible(),
Expand Down
Loading
Loading