Commit a1c9681

Author: hyun gyu kim
[TIR][Schedule] Add FuseReductionEpilogue primitive to fuse epilogue into reduction init
Currently it is not possible to fuse an epilogue operation (e.g., bias addition) into a reduction block's initialization statement. This limitation prevents leveraging hardware-specific instructions that support bias accumulation in vector ISAs, such as MACC (multiply-accumulate with bias) instructions.

This commit implements a new schedule primitive 'fuse_reduction_epilogue' that addresses the problem described in:
https://discuss.tvm.apache.org/t/tir-problem-inlining-addition-into-matmul-block/18066

The primitive transforms the following pattern:

Before:
    for i, j, k in T.grid(M, N, K):
        with T.block("matmul"):
            with T.init():
                temp[vi, vj] = 0
            temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]

    for i, j in T.grid(M, N):
        with T.block("bias_add"):
            D[vi, vj] = temp[vi, vj] + C[vi, vj]

After:
    for i, j, k in T.grid(M, N, K):
        with T.block("matmul"):
            T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
            T.writes(D[vi, vj])
            with T.init():
                D[vi, vj] = C[vi, vj]  # Fused epilogue into init
            D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk]

The transformation removes the intermediate temp buffer and the separate epilogue block, enabling better tensorization opportunities for hardware with bias accumulation support.

Implementation:
- ReductionEpilogueFuser class for pattern validation and IR transformation
- BodyPatternAllowFusion: Validates epilogue can be fused
- AnalyzeEpiloguePattern: Detects addition pattern (D = temp + C)
- ExtractEpilogueInfo: Extracts buffer and region information
- CreateFusedReductionBlock: Creates single block with modified T.init()
- SingleBlockFusionReplacer: Replaces blocks and removes temp buffer
- Variable mapping between epilogue and reduction block iter vars
- Proper buffer and region updates with correct read/write ordering
- FFI bindings and Python API following TVM conventions

Changes:
- src/tir/schedule/primitive/compute_inline.cc: Core implementation (~430 lines)
- src/tir/schedule/primitive.h: Function declaration
- include/tvm/tir/schedule/schedule.h: Virtual method in ScheduleNode
- src/tir/schedule/concrete_schedule.{h,cc}: ConcreteScheduleNode implementation
- src/tir/schedule/traced_schedule.{h,cc}: TracedScheduleNode implementation
- src/tir/schedule/schedule.cc: FFI binding registration
- python/tvm/tir/schedule/schedule.py: Python API with documentation
- tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py: Comprehensive tests including basic fusion, float32 variant, numerical correctness verification, and trace roundtrip validation

Run tests with:
    pytest tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py -v
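
The Before and After patterns in the commit message compute the same values whenever the arithmetic is exact. A minimal pure-Python sketch of that equivalence follows; it does not use TVM, the function names `matmul_then_bias` and `matmul_bias_in_init` are made up for illustration, and the `k == 0` check only stands in for what `T.init()` does on the first reduction step.

```python
def matmul_then_bias(A, B, C):
    """'Before' pattern: reduce into temp (init 0), then a separate bias-add loop."""
    M, K, N = len(A), len(A[0]), len(B)  # B is indexed B[j][k], as in the commit message
    temp = [[0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            for k in range(K):
                temp[i][j] = temp[i][j] + A[i][k] * B[j][k]
    # Separate epilogue: D = temp + C
    return [[temp[i][j] + C[i][j] for j in range(N)] for i in range(M)]


def matmul_bias_in_init(A, B, C):
    """'After' pattern: the reduction writes D directly and its init copies the bias."""
    M, K, N = len(A), len(A[0]), len(B)
    D = [[None] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            for k in range(K):
                if k == 0:  # stand-in for T.init(): runs on the first reduction step
                    D[i][j] = C[i][j]
                D[i][j] = D[i][j] + A[i][k] * B[j][k]
    return D


A = [[1, 2, 3], [4, 5, 6]]                        # M=2, K=3
B = [[1, 0, 1], [2, 1, 0], [0, 1, 1], [1, 1, 1]]  # N=4, K=3
C = [[10, 20, 30, 40], [50, 60, 70, 80]]          # bias, M x N

# Integer inputs keep the comparison exact.
assert matmul_then_bias(A, B, C) == matmul_bias_in_init(A, B, C)
```

With floating-point inputs the fused form adds the bias first instead of last, so the two results can differ by rounding; a tolerance-based comparison is the safer check in that case.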
1 parent 03d55df commit a1c9681

10 files changed: +644 -1 lines changed

include/tvm/tir/schedule/schedule.h

Lines changed: 7 additions & 0 deletions
@@ -608,6 +608,13 @@ class ScheduleNode : public runtime::Object {
    * \param block The block to be inlined to its producer
    */
   virtual void ReverseComputeInline(const BlockRV& block) = 0;
+  /*!
+   * \brief Fuse an epilogue block into a reduction block
+   * \param reduction_block The reduction block (e.g., matmul)
+   * \param epilogue_block The epilogue block to be fused (e.g., bias add)
+   */
+  virtual void FuseReductionEpilogue(const BlockRV& reduction_block,
+                                     const BlockRV& epilogue_block) = 0;
   /******** Schedule: Reduction ********/
   /*!
    * \brief Decompose a reduction block into two separate blocks.

python/tvm/tir/schedule/schedule.py

Lines changed: 27 additions & 0 deletions
@@ -2345,6 +2345,33 @@ def after_inline(a: T.handle, c: T.handle) -> None:
         # pylint: disable-next=no-member
         _ffi_api.ScheduleReverseComputeInline(self, block)  # type: ignore
 
+    @type_checked
+    def fuse_reduction_epilogue(
+        self,
+        reduction_block: Union[BlockRV, str],
+        epilogue_block: Union[BlockRV, str],
+    ) -> None:
+        """Fuse an epilogue block into a reduction block.
+
+        It requires:
+        1) The reduction block is a complete reduction block
+        2) The epilogue block only reads from the reduction block's output
+        3) The epilogue performs a simple addition: output = reduction_result + bias
+
+        Parameters
+        ----------
+        reduction_block : Union[BlockRV, str]
+            The reduction block (e.g., matmul)
+        epilogue_block : Union[BlockRV, str]
+            The epilogue block to be fused (e.g., bias add)
+        """
+        reduction_block = self._normalize_block_arg(reduction_block)
+        epilogue_block = self._normalize_block_arg(epilogue_block)
+        # pylint: disable-next=no-member
+        _ffi_api.ScheduleFuseReductionEpilogue(
+            self, reduction_block, epilogue_block
+        )  # type: ignore
+
     ########## Schedule: Reduction ##########
 
     @type_checked
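
Requirement 3 in the docstring restricts the epilogue to a plain addition of the reduction result and a bias. As a rough illustration of what such a shape check involves, the sketch below uses Python's `ast` module on a statement given as text. It is not the commit's implementation, which works in C++ on TIR nodes (the commit message names it AnalyzeEpiloguePattern). `match_bias_add` is a hypothetical helper, and accepting the operands in either order is an assumption.

```python
import ast


def match_bias_add(stmt, reduction_buffer):
    """Return (output_buffer, bias_buffer) if `stmt` has the form
    out[...] = reduction_buffer[...] + bias[...] (either operand order), else None.

    Hypothetical helper for illustration only; not part of TVM.
    """
    tree = ast.parse(stmt)
    if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Assign):
        return None
    assign = tree.body[0]
    if len(assign.targets) != 1:
        return None
    value = assign.value
    # The right-hand side must be a single binary addition.
    if not (isinstance(value, ast.BinOp) and isinstance(value.op, ast.Add)):
        return None

    def buffer_name(node):
        # A buffer access looks like name[indices].
        if isinstance(node, ast.Subscript) and isinstance(node.value, ast.Name):
            return node.value.id
        return None

    out = buffer_name(assign.targets[0])
    lhs = buffer_name(value.left)
    rhs = buffer_name(value.right)
    if out is None or lhs is None or rhs is None:
        return None
    # Exactly one operand must be the reduction result; the other is the bias.
    if lhs == reduction_buffer and rhs != reduction_buffer:
        return out, rhs
    if rhs == reduction_buffer and lhs != reduction_buffer:
        return out, lhs
    return None


assert match_bias_add("D[vi, vj] = temp[vi, vj] + C[vi, vj]", "temp") == ("D", "C")
assert match_bias_add("D[vi, vj] = temp[vi, vj] * C[vi, vj]", "temp") is None
```

The check is purely structural: it says nothing about whether the index expressions on both sides line up, which a real fusion pass would also need to validate.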

src/tir/schedule/concrete_schedule.cc

Lines changed: 9 additions & 0 deletions
@@ -832,6 +832,15 @@ void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) {
   this->state_->DebugVerify();
 }
 
+void ConcreteScheduleNode::FuseReductionEpilogue(const BlockRV& reduction_block_rv,
+                                                 const BlockRV& epilogue_block_rv) {
+  TVM_TIR_SCHEDULE_BEGIN();
+  tir::FuseReductionEpilogue(state_, this->GetSRef(reduction_block_rv),
+                             this->GetSRef(epilogue_block_rv));
+  TVM_TIR_SCHEDULE_END("fuse-reduction-epilogue", this->error_render_level_);
+  this->state_->DebugVerify();
+}
+
 /******** Schedule: Block Annotation ********/
 
 void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis,

src/tir/schedule/concrete_schedule.h

Lines changed: 2 additions & 0 deletions
@@ -147,6 +147,8 @@ class ConcreteScheduleNode : public ScheduleNode {
                      int index = -1) override;
   void ComputeInline(const BlockRV& block) override;
   void ReverseComputeInline(const BlockRV& block) override;
+  void FuseReductionEpilogue(const BlockRV& reduction_block,
+                             const BlockRV& epilogue_block) override;
   /******** Schedule: Reduction ********/
   BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override;
   BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) override;

src/tir/schedule/primitive.h

Lines changed: 8 additions & 0 deletions
@@ -509,6 +509,14 @@ TVM_DLL void ComputeInline(ScheduleState self, const StmtSRef& block_sref);
  * \param block_sref The sref to the block to be inlined to its producer
  */
 TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref);
+/*!
+ * \brief Fuse an epilogue block into a reduction block
+ * \param self The state of the schedule
+ * \param reduction_block_sref The sref to the reduction block
+ * \param epilogue_block_sref The sref to the epilogue block to be fused
+ */
+TVM_DLL void FuseReductionEpilogue(ScheduleState self, const StmtSRef& reduction_block_sref,
+                                   const StmtSRef& epilogue_block_sref);
 /******** Schedule: Reduction ********/
 /*!
  * \brief Decompose a reduction block into two separate blocks.
