5 changes: 5 additions & 0 deletions CMakeLists.txt
@@ -289,6 +289,11 @@ if(TRITON_BUILD_PYTHON_MODULE)
add_subdirectory(third_party/proton)
endif()

# TLX (Triton Low-level Language Extensions) for warp specialization
# NOTE: TLX C++ dialect is disabled pending MLIR integration
# The Python TLX API still works without the C++ dialect for config-based features
# add_subdirectory(third_party/tlx/dialect)

get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)
get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS)
set(TRITON_LIBRARIES
50 changes: 50 additions & 0 deletions README.md
@@ -6,6 +6,7 @@ FlagTree is an open source, unified compiler for multiple AI chips project dedic
Each backend is based on different versions of triton, and therefore resides in different protected branches ([main](https://github.com/flagos-ai/flagtree/tree/main) for triton 3.1, [triton_v3.2.x](https://github.com/flagos-ai/flagtree/tree/triton_v3.2.x), [triton_v3.3.x](https://github.com/flagos-ai/flagtree/tree/triton_v3.3.x), [triton_v3.4.x](https://github.com/flagos-ai/flagtree/tree/triton_v3.4.x), [triton_v3.5.x](https://github.com/flagos-ai/flagtree/tree/triton_v3.5.x)). All these protected branches have equal status. <br>

## Latest News
* 2026/01/28 **NEW** Added `@auto_pipeline` decorator for automatic multi-level pipelining optimization with up to 1.93x speedup on GEMM.
* 2025/12/24 Added support for pulling and installing pre-built [whl](/README.md#non-source-installation) packages.
* 2025/12/08 Added [enflame](https://github.com/FlagTree/flagtree/tree/triton_v3.3.x/third_party/enflame/) backend integration (based on Triton 3.3), and added CI/CD.
* 2025/11/26 Added the FlagTree backend specialization unified design document: [FlagTree_Backend_Specialization](reports/decoupling/).
@@ -37,6 +38,55 @@ Each backend is based on different versions of triton, and therefore resides in
* 2025/03/19 Added [mthreads](https://github.com/FlagTree/flagtree/tree/main/third_party/mthreads/) backend integration (based on Triton 3.1), and added CI/CD.
* 2025/03/12 Added [iluvatar](https://github.com/FlagTree/flagtree/tree/main/third_party/iluvatar/) backend integration (based on Triton 3.1), and added CI/CD.

## AutoPipeline: Advanced Multi-Level Pipelining

FlagTree introduces `@auto_pipeline`, a decorator that applies automatic multi-level pipelining optimization to Triton kernels. It delivers up to a 1.93x speedup on GEMM (see the benchmark results below) without requiring any manual kernel modifications.

### Features

- **Global-to-Shared (G2S) Pipelining**: Multi-stage async data prefetching from global to shared memory
- **Shared-to-Register (S2R) Pipelining**: Double-buffering optimization for shared memory to register transfers
- **Warp Specialization**: Producer-consumer pattern with dedicated prefetch and compute warps

### Quick Start

```python
import triton
import triton.language as tl
from triton.language import auto_pipeline, PipelineConfig, WarpSpecConfig

@triton.jit
@auto_pipeline(PipelineConfig(
global_to_shared_stages=4,
shared_to_register_stages=2,
enable_async_copy=True,
enable_swizzle=True,
enable_warp_specialization=True,
warp_spec_config=WarpSpecConfig(
num_producer_warps=1,
num_consumer_warps=3,
)
))
def matmul_kernel(A, B, C, M, N, K, ...):
# Standard GEMM implementation - no changes needed!
...
```

### Performance Results (2048x2048x2048 GEMM on A100)

| Kernel | TFLOPS | Speedup |
|--------|--------|---------|
| No Pipeline | 107.30 | 1.00x |
| Default Pipeline | 167.41 | 1.56x |
| **AutoPipeline** | **206.79** | **1.93x** |

### Run Benchmark

```shell
cd python/test
python benchmark_autopipeline.py
```

## Install from source
Installation dependencies (ensure you use the correct python3.x version):
```shell
3 changes: 3 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Passes.h
@@ -22,6 +22,9 @@ std::unique_ptr<OperationPass<ModuleOp>> createAllocateSharedMemoryPass();

} // namespace gpu

// Forward declaration for pipeline intrinsics pass
std::unique_ptr<OperationPass<ModuleOp>> createPipelineIntrinsicsToLLVMPass();

#define GEN_PASS_REGISTRATION
#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc"

14 changes: 14 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Passes.td
@@ -8,4 +8,18 @@ def AllocateSharedMemory : Pass<"allocate-shared-memory", "mlir::ModuleOp"> {
let constructor = "mlir::triton::gpu::createAllocateSharedMemoryPass()";
}

def ConvertPipelineIntrinsicsToLLVM : Pass<"convert-pipeline-intrinsics-to-llvm", "mlir::ModuleOp"> {
let summary = "Lower pipeline synchronization intrinsics to LLVM/NVVM";
let description = [{
Converts pipeline intrinsics (init, acquire, commit, wait, release, flush)
and async copy operations to LLVM IR primitives. On NVIDIA GPUs, this maps
to cp.async and barrier operations. Provides fallback for non-NVIDIA backends.
}];
let constructor = "mlir::triton::createPipelineIntrinsicsToLLVMPass()";
let dependentDialects = [
"mlir::LLVM::LLVMDialect",
"mlir::NVVM::NVVMDialect"
];
}

#endif
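
For context on the lowering described above, here is a minimal sketch of wiring the pass into a pass manager via the factory declared in `TritonGPUToLLVM/Passes.h`. The helper function name and where this runs in the real lowering pipeline are assumptions, not part of this change.

```cpp
// Minimal sketch (assumed wiring, not part of this PR): add the pipeline
// intrinsics lowering pass to an MLIR pass manager.
#include "mlir/Pass/PassManager.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"

// Hypothetical helper; the actual placement in the lowering pipeline is not
// shown in this diff.
static void addPipelineIntrinsicsLowering(mlir::PassManager &pm) {
  // Lowers pipeline init/acquire/commit/wait/release/flush intrinsics and
  // async copies; on NVIDIA targets these map to cp.async and barriers.
  pm.addPass(mlir::triton::createPipelineIntrinsicsToLLVMPass());
}
```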
27 changes: 27 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/AdvancedPipeliner.h
@@ -0,0 +1,27 @@
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_ADVANCEDPIPELINER_H
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_ADVANCEDPIPELINER_H

#include "triton/Dialect/TritonGPU/Transforms/BufferAccessAnalysis.h"
#include "triton/Dialect/TritonGPU/Transforms/PipelineOpportunityDetector.h"
#include "triton/Dialect/TritonGPU/Transforms/CircularBufferTransform.h"
#include "triton/Dialect/TritonGPU/Transforms/SynchronizationInsertion.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace triton {
namespace gpu {

// Forward declaration - actual pass class is generated by TableGen
// See: triton/Dialect/TritonGPU/Transforms/Passes.td
// The TableGen-generated pass is: impl::TritonGPUAdvancedPipelinerBase

/// Create the advanced pipeliner pass
/// Note: This function is auto-generated by TableGen
// std::unique_ptr<Pass> createAdvancedPipelinerPass();
// Use createTritonGPUAdvancedPipeliner() instead (from TableGen)

} // namespace gpu
} // namespace triton
} // namespace mlir

#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_ADVANCEDPIPELINER_H
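
Since the pass class itself is TableGen-generated, callers would go through the generated factory the header comment points to. A minimal sketch, assuming the generated declaration lives in the usual TritonGPU `Transforms/Passes.h` and takes no options:

```cpp
// Minimal sketch (assumptions: generated factory is declared in
// TritonGPU/Transforms/Passes.h and takes no options).
#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

static void addAdvancedPipeliner(mlir::PassManager &pm) {
  // TableGen-generated factory for the TritonGPUAdvancedPipeliner pass.
  pm.addPass(mlir::triton::gpu::createTritonGPUAdvancedPipeliner());
}
```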
139 changes: 139 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/BufferAccessAnalysis.h
@@ -0,0 +1,139 @@
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_BUFFERACCESSANALYSIS_H
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_BUFFERACCESSANALYSIS_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include <memory>

namespace mlir {
namespace triton {
namespace gpu {

/// Memory scope classification for buffers
enum class MemoryScope {
Global, /// Global memory (HBM/VRAM)
Shared, /// Shared memory (SMEM/LDS)
Register, /// Register file
Unknown /// Cannot determine
};

/// Information about how a buffer is accessed
struct BufferAccessInfo {
/// The buffer value being accessed
Value buffer;

/// Memory scope of the buffer
MemoryScope scope;

/// Operation that produces/writes to this buffer (may be null)
Operation *producer;

/// Operations that consume/read from this buffer
SmallVector<Operation *> consumers;

/// First access in program order
Operation *firstAccess;

/// Last access in program order
Operation *lastAccess;

/// Enclosing loop context (if accessed within a loop)
scf::ForOp loopContext;

/// Lowest common ancestor of all accesses
Operation *lca;

/// Access pattern metadata
bool isSequential;
bool isStrided;
int64_t stride;
int64_t elementCount;

/// Element type of the buffer
Type elementType;

/// Predecessor buffer (for data flow tracking)
Value predecessorBuffer;

/// Enhanced: Block pointer tracking
bool isBlockPtr;

/// Enhanced: Global→Shared transfer detection
bool isGlobalToShared;

BufferAccessInfo()
: buffer(nullptr), scope(MemoryScope::Unknown), producer(nullptr),
firstAccess(nullptr), lastAccess(nullptr), loopContext(nullptr),
lca(nullptr), isSequential(false), isStrided(false), stride(1),
elementCount(0), elementType(nullptr), predecessorBuffer(nullptr),
isBlockPtr(false), isGlobalToShared(false) {}
};

/// Analysis pass for tracking buffer accesses and dependencies
class BufferAccessAnalysis {
public:
BufferAccessAnalysis() = default;

/// Run analysis on a function
void analyze(triton::FuncOp function);

/// Get access information for a buffer
BufferAccessInfo *getAccessInfo(Value buffer);

/// Get all buffers accessed within a loop
SmallVector<Value> getBuffersInLoop(scf::ForOp loop);

/// Check if a buffer can be pipelined
bool isPipelinable(Value buffer);

/// Compute the lowest common ancestor of all buffer accesses
Operation *computeLCA(Value buffer);

/// Clear all analysis results
void clear();

private:
/// Map from buffer to access information
DenseMap<Value, std::unique_ptr<BufferAccessInfo>> bufferInfoMap;

/// Map from block pointer to base pointer (for tracking global memory sources)
DenseMap<Value, Value> blockPtrMap;

/// Current loop nesting during traversal
SmallVector<scf::ForOp> loopStack;

/// Operation nesting for LCA computation
SmallVector<Operation *> opStack;

/// Visitor functions
void visitOperation(Operation *op);
void visitAllocation(Operation *allocOp);
void visitLoad(Operation *loadOp);
void visitStore(Operation *storeOp);
void visitForLoop(scf::ForOp forOp);

/// Enhanced visitor functions for block pointers and shared memory
void visitMakeTensorPtr(Operation *makeTensorPtrOp);
void visitLocalAlloc(Operation *localAllocOp);
void visitLocalLoad(Operation *localLoadOp);
void visitLocalStore(Operation *localStoreOp);

/// Helper functions
Value getBaseBuffer(Value ptr);
MemoryScope determineMemoryScope(Value buffer);
void analyzeAccessPattern(Operation *memOp, BufferAccessInfo *info);
Operation *findLowestCommonAncestor(Operation *op1, Operation *op2);
bool hasMemoryDependency(BufferAccessInfo *info);
};

} // namespace gpu
} // namespace triton
} // namespace mlir

#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_BUFFERACCESSANALYSIS_H
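
The class above is a plain analysis object rather than an MLIR pass, so a client drives it directly. Below is a minimal usage sketch that uses only the interface declared in this header; the loop walk and what is done with the results are illustrative.

```cpp
// Minimal usage sketch: run the analysis on a function and collect buffers
// it considers safe to pipeline inside each loop.
#include "triton/Dialect/TritonGPU/Transforms/BufferAccessAnalysis.h"

using namespace mlir;
using namespace mlir::triton::gpu;

static void collectPipelinableBuffers(triton::FuncOp func,
                                      SmallVectorImpl<Value> &result) {
  BufferAccessAnalysis analysis;
  analysis.analyze(func);

  func.walk([&](scf::ForOp loop) {
    for (Value buffer : analysis.getBuffersInLoop(loop)) {
      BufferAccessInfo *info = analysis.getAccessInfo(buffer);
      // Keep buffers the analysis marks as pipelinable, e.g. detected
      // global-to-shared transfers accessed within this loop.
      if (info && analysis.isPipelinable(buffer))
        result.push_back(buffer);
    }
  });
}
```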
111 changes: 111 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/CircularBufferTransform.h
@@ -0,0 +1,111 @@
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_CIRCULARBUFFERTRANSFORM_H
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_CIRCULARBUFFERTRANSFORM_H

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/PipelineOpportunityDetector.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Value.h"

namespace mlir {
namespace triton {
namespace gpu {

/// Information about a circular buffer transformation
struct CircularBufferInfo {
/// Original buffer allocation
Value originalBuffer;

/// New circular buffer (expanded with stage dimension)
Value circularBuffer;

/// Number of pipeline stages
unsigned numStages;

/// Stride in elements between stages
int64_t stride;

/// Loop being pipelined
scf::ForOp loop;

/// Associated pipeline ID
unsigned pipelineId;

/// Use async copy intrinsics
bool useAsyncCopy;

/// Use swizzled indexing
bool useSwizzle;

CircularBufferInfo()
: originalBuffer(nullptr), circularBuffer(nullptr), numStages(1),
stride(0), loop(nullptr), pipelineId(0), useAsyncCopy(false),
useSwizzle(false) {}
};

/// Transform buffer allocations and accesses to use circular buffering
class CircularBufferTransform {
public:
explicit CircularBufferTransform(OpBuilder &builder) : builder(builder) {}

/// Transform a buffer allocation to circular buffer
CircularBufferInfo transformAllocation(const PipelineOpportunity &opp,
unsigned pipelineId);

/// Transform a store operation to use circular buffer indexing
void transformStore(Operation *storeOp, CircularBufferInfo &info);

/// Transform a load operation to use circular buffer indexing
void transformLoad(Operation *loadOp, CircularBufferInfo &info);

/// Transform a LocalStoreOp (Global→Shared or Register→Shared)
void transformLocalStore(Operation *localStoreOp, CircularBufferInfo &info);

/// Transform a LocalLoadOp (Shared→Register) for register pipelining
void transformLocalLoad(Operation *localLoadOp, CircularBufferInfo &info);

/// Transform a global LoadOp to use async copy (Global→Shared pipelining)
/// This is the key method that generates cp.async operations
void transformGlobalLoad(triton::LoadOp loadOp, CircularBufferInfo &info,
Value insertIdx, Value extractIdx);

/// Allocate shared memory buffer for a load operation
Value allocateSharedBuffer(triton::LoadOp loadOp, unsigned numStages);

/// Get appropriate shared encoding for a load type
Attribute getSharedEncodingForLoad(triton::LoadOp loadOp);

private:
OpBuilder &builder;

/// Compute circular buffer offset for store (producer side)
/// Formula: ((global_iter + numStages - 1) % numStages) * stride
Value computeCircularOffsetStore(Location loc, Value globalIter,
unsigned numStages, int64_t stride);

/// Compute circular buffer offset for load (consumer side)
/// Formula: (global_iter % numStages) * stride
Value computeCircularOffsetLoad(Location loc, Value globalIter,
unsigned numStages, int64_t stride);

/// Compute global iteration number from potentially nested loops
Value computeGlobalIteration(scf::ForOp loop);

/// Decompose pointer into base and indices
std::pair<Value, SmallVector<Value>> decomposePointer(Value ptr);

/// Build new pointer with circular buffer dimension
Value buildPointer(Value baseBuffer, ArrayRef<Value> indices);

/// Apply swizzle pattern to reduce bank conflicts
Value applySwizzle(Value ptr, CircularBufferInfo &info);

/// Substitute loop variable with new value in an operation tree
void substituteLoopVariable(Operation *op, Value oldVar, Value newVar);
};

} // namespace gpu
} // namespace triton
} // namespace mlir

#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_CIRCULARBUFFERTRANSFORM_H
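
To make the intended call sequence concrete, here is a minimal sketch of applying the transform to a single detected opportunity. `PipelineOpportunity` comes from `PipelineOpportunityDetector.h`, which is not shown in this diff, so how the opportunity and its store/load operations are obtained is assumed.

```cpp
// Minimal sketch (assumptions: the opportunity and its store/load ops were
// already found by the detector; only the declared interface is used here).
#include "triton/Dialect/TritonGPU/Transforms/CircularBufferTransform.h"

using namespace mlir;
using namespace mlir::triton::gpu;

static void applyCircularBuffering(OpBuilder &builder,
                                   const PipelineOpportunity &opp,
                                   Operation *storeOp, Operation *loadOp,
                                   unsigned pipelineId) {
  CircularBufferTransform transform(builder);

  // Expand the original allocation with a stage dimension.
  CircularBufferInfo info = transform.transformAllocation(opp, pipelineId);

  // Producer side: offset = ((iter + numStages - 1) % numStages) * stride.
  transform.transformStore(storeOp, info);
  // Consumer side: offset = (iter % numStages) * stride.
  transform.transformLoad(loadOp, info);
}
```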