[TIR][Schedule] Add FuseReductionEpilogue primitive to fuse epilogue into reduction init
Currently it is not possible to fuse an epilogue operation (e.g., bias addition)
into a reduction block's initialization statement. This limitation prevents
leveraging hardware-specific instructions that support bias accumulation in
vector ISAs, such as MACC (multiply-accumulate with bias) instructions.
This commit implements a new schedule primitive 'fuse_reduction_epilogue' that
addresses the problem described in:
https://discuss.tvm.apache.org/t/tir-problem-inlining-addition-into-matmul-block/18066
The primitive transforms the following pattern:
Before:
    for i, j, k in T.grid(M, N, K):
        with T.block("matmul"):
            with T.init():
                temp[vi, vj] = 0
            temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]
    for i, j in T.grid(M, N):
        with T.block("bias_add"):
            D[vi, vj] = temp[vi, vj] + C[vi, vj]
After:
    for i, j, k in T.grid(M, N, K):
        with T.block("matmul"):
            T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
            T.writes(D[vi, vj])
            with T.init():
                D[vi, vj] = C[vi, vj]  # Fused epilogue into init
            D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk]
The transformation removes the intermediate temp buffer and the separate
epilogue block, enabling better tensorization opportunities for hardware
with bias accumulation support.
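
As a rough illustration of why the rewrite preserves results, the sketch below models both loop nests in plain Python (no TVM involved). The function names matmul_then_bias and matmul_bias_in_init and the sample data are hypothetical, chosen only for this example. B is indexed as B[j][k] to match B[vj, vk] above.

```python
def matmul_then_bias(A, B, C):
    """Unfused form: zero-initialised reduction into temp, then a bias_add pass."""
    M, N, K = len(A), len(B), len(A[0])
    temp = [[0] * N for _ in range(M)]          # T.init(): temp[vi, vj] = 0
    for i in range(M):
        for j in range(N):
            for k in range(K):
                temp[i][j] = temp[i][j] + A[i][k] * B[j][k]
    # Separate epilogue block: D = temp + C
    return [[temp[i][j] + C[i][j] for j in range(N)] for i in range(M)]


def matmul_bias_in_init(A, B, C):
    """Fused form: the reduction is initialised from the bias, no temp buffer."""
    M, N, K = len(A), len(B), len(A[0])
    D = [[C[i][j] for j in range(N)] for i in range(M)]  # T.init(): D = C
    for i in range(M):
        for j in range(N):
            for k in range(K):
                D[i][j] = D[i][j] + A[i][k] * B[j][k]
    return D


A = [[1, 2, 3], [4, 5, 6]]                            # M x K
B = [[1, 0, 1], [0, 1, 0], [2, 2, 2], [1, 1, 1]]      # N x K
C = [[10, 20, 30, 40], [50, 60, 70, 80]]              # M x N

assert matmul_then_bias(A, B, C) == matmul_bias_in_init(A, B, C)
```

With integer inputs the two forms agree exactly. With floating-point inputs the bias is added first instead of last, so the results can differ by rounding and are better compared with a tolerance.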
Implementation:
- ReductionEpilogueFuser class for pattern validation and IR transformation
  - BodyPatternAllowFusion: Validates epilogue can be fused
  - AnalyzeEpiloguePattern: Detects addition pattern (D = temp + C)
  - ExtractEpilogueInfo: Extracts buffer and region information
  - CreateFusedReductionBlock: Creates single block with modified T.init()
- SingleBlockFusionReplacer: Replaces blocks and removes temp buffer
- Variable mapping between epilogue and reduction block iter vars
- Proper buffer and region updates with correct read/write ordering
- FFI bindings and Python API following TVM conventions
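
The pattern detection and iter-var mapping steps listed above can be sketched on a toy expression tree. This is not the C++ code in compute_inline.cc: the Load and Add classes and the analyze_epilogue and remap helpers are hypothetical stand-ins, and accepting the temp load on either side of the addition is an assumption of this sketch.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union


@dataclass(frozen=True)
class Load:
    """Toy buffer load such as temp[vi, vj]."""
    buffer: str
    indices: Tuple[str, ...]


@dataclass(frozen=True)
class Add:
    """Toy binary addition node."""
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[Load, Add]


def analyze_epilogue(value: Expr, temp: str) -> Optional[Load]:
    """Return the bias load if value is `temp[...] + X[...]`, else None."""
    if not isinstance(value, Add):
        return None
    lhs, rhs = value.lhs, value.rhs
    if not (isinstance(lhs, Load) and isinstance(rhs, Load)):
        return None
    if lhs.buffer == temp and rhs.buffer != temp:
        return rhs
    if rhs.buffer == temp and lhs.buffer != temp:
        return lhs
    return None


def remap(load: Load, var_map: Dict[str, str]) -> Load:
    """Rewrite epilogue iter vars into the reduction block's iter vars."""
    return Load(load.buffer, tuple(var_map.get(v, v) for v in load.indices))


# Epilogue body: D[vi_e, vj_e] = temp[vi_e, vj_e] + C[vi_e, vj_e]
epilogue_value = Add(Load("temp", ("vi_e", "vj_e")), Load("C", ("vi_e", "vj_e")))
bias = analyze_epilogue(epilogue_value, temp="temp")

# The new init value, expressed with the reduction block's spatial iter vars.
init_value = remap(bias, {"vi_e": "vi", "vj_e": "vj"})
assert init_value == Load("C", ("vi", "vj"))
```

Returning None for anything other than a plain two-operand addition keeps the toy matcher conservative, so an unsupported epilogue would simply be rejected.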
Changes:
- src/tir/schedule/primitive/compute_inline.cc: Core implementation (~430 lines)
- src/tir/schedule/primitive.h: Function declaration
- include/tvm/tir/schedule/schedule.h: Virtual method in ScheduleNode
- src/tir/schedule/concrete_schedule.{h,cc}: ConcreteScheduleNode implementation
- src/tir/schedule/traced_schedule.{h,cc}: TracedScheduleNode implementation
- src/tir/schedule/schedule.cc: FFI binding registration
- python/tvm/tir/schedule/schedule.py: Python API with documentation
- tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py:
  Comprehensive tests including basic fusion, float32 variant, numerical
  correctness verification, and trace roundtrip validation
Run tests with:
    pytest tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py -v