diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp index 8ef8d72e4b2a..4bea012ff2fc 100644 --- a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp +++ b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp @@ -171,11 +171,20 @@ struct AttentionOpConversion loc, rewriter.getF32Type(), dimInt); Value scale = rewriter.createOrFold(loc, dimFloat); + int64_t numBatches = op.getQueryType().getRank() - 2; + + // When the TMTensor op is marked causal, fuse the mask into the + // attention region body using iree_linalg_ext.index ops and drop the + // materialized mask operand. + bool causal = op.getIsCausal().value_or(false); + if (causal) { + optionalMask = std::nullopt; + } + // Add batches to standard attention indexing maps. SmallVector indexingMaps = getStandardAttentionIndexingMaps(ctx, optionalMask.has_value()); - int64_t numBatches = op.getQueryType().getRank() - 2; for (AffineMap &map : indexingMaps) { map = map.shiftDims(numBatches); if (map.getNumResults() == 0) { @@ -196,7 +205,29 @@ struct AttentionOpConversion block->addArgument(rewriter.getF32Type(), loc); rewriter.setInsertionPoint(block, block->begin()); - IREE::LinalgExt::YieldOp::create(rewriter, loc, block->getArgument(0)); + if (causal) { + // In the standard layout after shiftDims(numBatches): + // m = numBatches, k2 = numBatches + 3. + int64_t mDim = numBatches; + int64_t k2Dim = numBatches + 3; + + Value mIdx = IREE::LinalgExt::IndexOp::create( + rewriter, loc, rewriter.getIndexType(), mDim); + Value k2Idx = IREE::LinalgExt::IndexOp::create( + rewriter, loc, rewriter.getIndexType(), k2Dim); + Value cmp = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::ugt, k2Idx, mIdx); + // Use the element type of the score (f32). + Value negInf = arith::ConstantOp::create( + rewriter, loc, + rewriter.getFloatAttr(rewriter.getF32Type(), -INFINITY)); + Value score = block->getArgument(0); + Value masked = + arith::SelectOp::create(rewriter, loc, cmp, negInf, score); + IREE::LinalgExt::YieldOp::create(rewriter, loc, masked); + } else { + IREE::LinalgExt::YieldOp::create(rewriter, loc, block->getArgument(0)); + } } rewriter.replaceOp(op, attention.getResult(0)); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp index c9a294df22de..1aacb7c351ef 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp @@ -187,7 +187,8 @@ static Value computeMatmul(OpBuilder &builder, Location loc, AffineMap lhsMap, } static Value applyPostQKMatmulElementwise(OpBuilder &builder, Location loc, - Region ®ion, Value value) { + Region ®ion, AffineMap sMap, + Value value) { auto rank = cast(value.getType()).getRank(); AffineMap identityMap = AffineMap::getMultiDimIdentityMap(rank, builder.getContext()); @@ -199,6 +200,34 @@ static Value applyPostQKMatmulElementwise(OpBuilder &builder, Location loc, value, indexingMaps, iteratorTypes); auto &dstRegion = genericOp.getRegion(); builder.cloneRegionBefore(region, dstRegion, dstRegion.end()); + + // Build a mapping from attention iteration domain dim -> S tensor dim. + // The linalg.generic uses an identity map over S, so linalg iteration + // dim i == S tensor dim i. + DenseMap attentionDimToSDim; + for (auto [sIdx, expr] : llvm::enumerate(sMap.getResults())) { + attentionDimToSDim[cast(expr).getPosition()] = sIdx; + } + + // Replace iree_linalg_ext.index ops with linalg.index ops. + SmallVector indexOps; + for (auto indexOp : dstRegion.back().getOps()) { + indexOps.push_back(indexOp); + } + { + OpBuilder::InsertionGuard guard(builder); + for (auto indexOp : indexOps) { + auto it = attentionDimToSDim.find(indexOp.getDim()); + assert(it != attentionDimToSDim.end() && + "index op dim not found in S map"); + builder.setInsertionPoint(indexOp); + Value linalgIdx = + linalg::IndexOp::create(builder, loc, it->second)->getResult(0); + indexOp.replaceAllUsesWith(linalgIdx); + indexOp.erase(); + } + } + { OpBuilder::InsertionGuard withinRegion(builder); builder.setInsertionPoint(dstRegion.back().getTerminator()); @@ -350,7 +379,7 @@ Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query, s.getDefiningOp()->setAttrs(qkAttrs); } - s = applyPostQKMatmulElementwise(b, loc, elementwiseRegion, s); + s = applyPostQKMatmulElementwise(b, loc, elementwiseRegion, sMap, s); if (lowPrecision) { // For low bit-depth types we perform post Q @ K scaling. This is to avoid diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp index c32c7589fb84..50e07625d403 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp @@ -3097,19 +3097,25 @@ CustomOp::reifyResultShapes(OpBuilder &builder, LogicalResult IREE::LinalgExt::IndexOp::verify() { auto parentOp = getOperation()->getParentOp(); - if (!isa(parentOp)) { + if (!isa(parentOp)) { return emitOpError( "expected parent op to be one of `iree_linalg_ext.custom_op`, " - "`iree_linalg_ext.attention`"); + "`iree_linalg_ext.attention`, `iree_linalg_ext.online_attention`"); } - auto customOp = dyn_cast(parentOp); - auto attentionOp = dyn_cast(parentOp); int64_t numLoops = - customOp ? customOp.getNumLoops() : attentionOp.getNumLoops(); + TypeSwitch(parentOp) + .Case( + [](CustomOp op) -> int64_t { return op.getNumLoops(); }) + .Case([](AttentionOp op) -> int64_t { + return op.getIterationDomainRank(); + }) + .Case([](OnlineAttentionOp op) -> int64_t { + return op.getIterationDomainRank(); + }); if (numLoops <= getDim()) { return emitOpError("expected dim (") << getDim() << ") to be lower than the number of loops (" << numLoops - << ") of the enclosing CustomOp/AttentionOp"; + << ") of the enclosing operation"; } return success(); } diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtPureOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtPureOps.td index e6dc3c0393a3..4333439f611d 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtPureOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtPureOps.td @@ -36,7 +36,8 @@ def IREELinalgExt_IndexOp : IREELinalgExt_PureOp<"index", [Pure]>, This operation is a mirror of `linalg.index` operation and has the same semantics, except that `linalg.index` enforces that the parent op is a `LinalgOp`, and the `iree_linalg_ext.index` operation enforces that the - parent op is one of `IREE::LinalgExt::CustomOp` or `IREE::LinalgExt::AttentionOp`. + parent op is one of `IREE::LinalgExt::CustomOp`, + `IREE::LinalgExt::AttentionOp`, or `IREE::LinalgExt::OnlineAttentionOp`. }]; let assemblyFormat = [{ $dim attr-dict `:` type($result) }]; diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp index 35b0ea09019c..efab54105420 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp @@ -2770,6 +2770,28 @@ SmallVector AttentionOp::getLoopIteratorTypes() { getKeyMap(), getValueMap(), getOutputMap()); } +static void offsetAttentionIndices(OpBuilder &b, Region &body, + ArrayRef offsets) { + IRRewriter rewriter(b); + for (auto indexOp : body.getOps()) { + if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()]) { + continue; + } + OpBuilder::InsertionGuard guard(b); + rewriter.setInsertionPointAfter(indexOp); + AffineExpr index, offset; + bindDims(b.getContext(), index, offset); + OpFoldResult applied = affine::makeComposedFoldedAffineApply( + rewriter, indexOp.getLoc(), index + offset, + {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]}); + Value materialized = + getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied); + rewriter.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) { + return use.getOwner() != materialized.getDefiningOp(); + }); + } +} + FailureOr AttentionOp::getTiledImplementation(OpBuilder &builder, ArrayRef offsets, @@ -2777,12 +2799,6 @@ AttentionOp::getTiledImplementation(OpBuilder &builder, assert(offsets.size() == getIterationDomainRank()); assert(sizes.size() == getIterationDomainRank()); - // TODO: Add support for linalg_ext.index operations in the region. - // Currently, tiling will break if index operations are present. - if (!getBody()->getOps().empty()) { - return failure(); - } - Location loc = getLoc(); SmallVector querySlice = @@ -2847,6 +2863,7 @@ AttentionOp::getTiledImplementation(OpBuilder &builder, Operation *tiledOp = mlir::clone(builder, getOperation(), resultTypes, tiledOperands); + offsetAttentionIndices(builder, tiledOp->getRegion(0), offsets); return TilingResult{ {tiledOp}, SmallVector(tiledOp->getResults()), slices}; @@ -3006,6 +3023,7 @@ OnlineAttentionOp::getTiledImplementation(OpBuilder &builder, Operation *tiledOp = mlir::clone(builder, getOperation(), resultTypes, tiledOperands); + offsetAttentionIndices(builder, tiledOp->getRegion(0), offsets); return TilingResult{ {tiledOp}, SmallVector(tiledOp->getResults()), slices}; diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir index 2902eb2feac3..559ec215b1f7 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir @@ -1781,20 +1781,48 @@ func.func @custom_op_yield_type_mismatch(%arg0 : tensor, %arg1 : tensor<1 // ----- func.func @index_op_outside_custom_op() -> index { - // expected-error @+1 {{expected parent op to be one of `iree_linalg_ext.custom_op`, `iree_linalg_ext.attention`}} + // expected-error @+1 {{expected parent op to be one of `iree_linalg_ext.custom_op`, `iree_linalg_ext.attention`, `iree_linalg_ext.online_attention`}} %0 = iree_linalg_ext.index 0 : index return %0 : index } // ----- +func.func @index_op_invalid_dim_online_attention( + %query: tensor<192x1024x64xf16>, + %key: tensor<192x1024x64xf16>, + %value: tensor<192x1024x64xf16>, + %output: tensor<192x1024x64xf32>, + %max: tensor<192x1024xf32>, + %sum: tensor<192x1024xf32>) -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) { + %scale = arith.constant 1.0 : f32 + %out:3 = iree_linalg_ext.online_attention + {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>]} + ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f32) + outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) { + ^bb0(%score: f32): + // expected-error @+1 {{expected dim (5) to be lower than the number of loops (5) of the enclosing operation}} + %idx = iree_linalg_ext.index 5 : index + iree_linalg_ext.yield %score : f32 + } -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32> + return %out#0, %out#1, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32> +} + +// ----- + func.func @index_op_invalid_dim(%arg0 : tensor) -> tensor { %0 = iree_linalg_ext.custom_op { indexing_maps = [affine_map<(d0) -> (d0)>], iterator_types = [#iree_linalg_ext.iterator_type]} outs(%arg0: tensor) { ^bb0(%b0 : tensor): - // expected-error @+1 {{expected dim (1) to be lower than the number of loops (1) of the enclosing CustomOp}} + // expected-error @+1 {{expected dim (1) to be lower than the number of loops (1) of the enclosing operation}} %1 = iree_linalg_ext.index 1 : index %2 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>], diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir index ac54caa453eb..322971c2f5a1 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir @@ -1982,6 +1982,40 @@ module { // ----- +func.func @attention_causal(%arg0: tensor<192x1024x64xf32>, %arg1: tensor<192x1024x64xf32>, %arg2: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> { + %cst = arith.constant dense<0.000000e+00> : tensor<192x1024x64xf32> + %scale = arith.constant 1.000000e+00 : f32 + %0 = iree_linalg_ext.attention {indexing_maps = [ + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> + ] + } ins(%arg0, %arg1, %arg2, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%cst : tensor<192x1024x64xf32>) { + ^bb0(%score: f32): + %m = iree_linalg_ext.index 1 : index + %k2 = iree_linalg_ext.index 3 : index + %cmp = arith.cmpi ugt, %k2, %m : index + %neg_inf = arith.constant 0xFF800000 : f32 + %masked = arith.select %cmp, %neg_inf, %score : f32 + iree_linalg_ext.yield %masked : f32 + } -> tensor<192x1024x64xf32> + return %0 : tensor<192x1024x64xf32> +} + +// CHECK-LABEL: func.func @attention_causal( +// CHECK: iree_linalg_ext.attention +// CHECK: ^bb0(%[[SCORE:.+]]: f32): +// CHECK: %[[M:.+]] = iree_linalg_ext.index 1 : index +// CHECK: %[[K2:.+]] = iree_linalg_ext.index 3 : index +// CHECK: %[[CMP:.+]] = arith.cmpi ugt, %[[K2]], %[[M]] : index +// CHECK: %[[NEG_INF:.+]] = arith.constant 0xFF800000 : f32 +// CHECK: %[[MASKED:.+]] = arith.select %[[CMP]], %[[NEG_INF]], %[[SCORE]] : f32 +// CHECK: iree_linalg_ext.yield %[[MASKED]] : f32 + +// ----- + func.func @custom_op_default(%arg0 : tensor, %arg1 : tensor) -> tensor { %0 = iree_linalg_ext.custom_op { indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir index 24f953d03ae1..2fe2ccfa4702 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir @@ -40,3 +40,40 @@ func.func @attention(%q: tensor<2x10x4096x128xf16>, %k: tensor<2x10x4096x128xf16 // CHECK: arith.mulf // CHECK: arith.truncf // CHECK: linalg.yield + +// ----- + +#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)> +#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> ()> +#map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + +func.func @attention_causal(%q: tensor<2x10x4096x128xf16>, %k: tensor<2x10x4096x128xf16>, %v: tensor<2x10x4096x128xf16>) + -> tensor<2x10x4096x128xf16> { + %scale = arith.constant 0.125 : f16 + %acc = tensor.empty() : tensor<2x10x4096x128xf16> + %out = iree_linalg_ext.attention + {indexing_maps = [#map, #map1, #map2, #map3, #map4]} + ins(%q, %k, %v, %scale : tensor<2x10x4096x128xf16>, tensor<2x10x4096x128xf16>, tensor<2x10x4096x128xf16>, f16) + outs(%acc : tensor<2x10x4096x128xf16>) { + ^bb0(%score: f32): + %m = iree_linalg_ext.index 2 : index + %k2 = iree_linalg_ext.index 5 : index + %cmp = arith.cmpi ugt, %k2, %m : index + %neg_inf = arith.constant 0xFF800000 : f32 + %masked = arith.select %cmp, %neg_inf, %score : f32 + iree_linalg_ext.yield %masked : f32 + } -> tensor<2x10x4096x128xf16> + func.return %out : tensor<2x10x4096x128xf16> +} + +// CHECK-LABEL: func.func @attention_causal +// CHECK: iree_linalg_ext.online_attention +// CHECK-NEXT: ^{{.+}}(%[[SCORE:.+]]: f32): +// CHECK-NEXT: %[[M:.+]] = iree_linalg_ext.index 2 : index +// CHECK-NEXT: %[[K2:.+]] = iree_linalg_ext.index 5 : index +// CHECK-NEXT: %[[CMP:.+]] = arith.cmpi ugt, %[[K2]], %[[M]] : index +// CHECK-NEXT: %[[NEG_INF:.+]] = arith.constant 0xFF800000 : f32 +// CHECK-NEXT: %[[MASKED:.+]] = arith.select %[[CMP]], %[[NEG_INF]], %[[SCORE]] : f32 +// CHECK-NEXT: iree_linalg_ext.yield %[[MASKED]] : f32 diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_aggregated_ops.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_aggregated_ops.mlir index 313d02735974..1c975b19ad11 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_aggregated_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_aggregated_ops.mlir @@ -89,3 +89,59 @@ func.func @online_attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>, // CHECK: arith.mulf // CHECK: arith.addf // CHECK: linalg.yield + +// ----- + +// Test that iree_linalg_ext.index ops in the attention region are remapped +// to linalg.index ops in the decomposed output (causal masking pattern). + +#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> +#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> +#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> +#mapS = affine_map<(batch, m, k1, k2, n) -> ()> +#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)> +#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> + +func.func @online_attention_causal( + %query: tensor<4x1024x64xf16>, + %key: tensor<4x1024x64xf16>, + %value: tensor<4x1024x64xf16>, + %output: tensor<4x1024x64xf32>, + %max: tensor<4x1024xf32>, + %sum: tensor<4x1024xf32>) + -> (tensor<4x1024x64xf32>, tensor<4x1024xf32>, tensor<4x1024xf32>) { + %scale = arith.constant 1.0 : f32 + %out:3 = iree_linalg_ext.online_attention + {indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR]} + ins(%query, %key, %value, %scale : tensor<4x1024x64xf16>, tensor<4x1024x64xf16>, tensor<4x1024x64xf16>, f32) + outs(%output, %max, %sum : tensor<4x1024x64xf32>, tensor<4x1024xf32>, tensor<4x1024xf32>) { + ^bb0(%score: f32): + %m = iree_linalg_ext.index 1 : index + %k2 = iree_linalg_ext.index 3 : index + %cmp = arith.cmpi ugt, %k2, %m : index + %neg_inf = arith.constant 0xFF800000 : f32 + %masked = arith.select %cmp, %neg_inf, %score : f32 + iree_linalg_ext.yield %masked : f32 + } -> tensor<4x1024x64xf32>, tensor<4x1024xf32>, tensor<4x1024xf32> + return %out#0, %out#1, %out#2 : tensor<4x1024x64xf32>, tensor<4x1024xf32>, tensor<4x1024xf32> +} + +// CHECK-LABEL: @online_attention_causal +// S = Q @ K +// CHECK: linalg.generic +// CHECK: arith.extf +// CHECK: arith.extf +// CHECK: arith.mulf +// CHECK: arith.addf +// CHECK: linalg.yield +// S = S * scale (pre-applied to Q) +// Post QK matmul elementwise (the causal masking region): +// iree_linalg_ext.index ops should be remapped to linalg.index ops. +// sMap = (batch, m, k1, k2, n) -> (batch, m, k2) +// So attention dim 1 (m) -> S dim 1, attention dim 3 (k2) -> S dim 2 +// CHECK: linalg.generic +// CHECK: %[[M_IDX:.+]] = linalg.index 1 +// CHECK: %[[K2_IDX:.+]] = linalg.index 2 +// CHECK: arith.cmpi ugt, %[[K2_IDX]], %[[M_IDX]] +// CHECK: arith.select +// CHECK: linalg.yield diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir index 583d81071d21..d4b066131e87 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir @@ -2307,6 +2307,50 @@ module attributes { transform.with_named_sequence } { // ----- +func.func @attention_causal(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> { + %0 = tensor.empty() : tensor<192x1024x64xf32> + %scale = arith.constant 1.0 : f32 + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) { + ^bb0(%score: f32): + %m = iree_linalg_ext.index 1 : index + %k2 = iree_linalg_ext.index 3 : index + %cmp = arith.cmpi ugt, %k2, %m : index + %neg_inf = arith.constant 0xFF800000 : f32 + %masked = arith.select %cmp, %neg_inf, %score : f32 + iree_linalg_ext.yield %masked : f32 + } -> tensor<192x1024x64xf32> + return %1 : tensor<192x1024x64xf32> +} +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %module_op : (!transform.any_op) -> !transform.any_op + %1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [10, 30] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} + +// Verify that index ops are preserved in the tiled attention and offsets are applied. +// The m dimension (dim 1) is tiled, so the index op for m gets an offset. +// The k2 dimension (dim 3) is not tiled, so no offset is applied. +// CHECK-LABEL: func.func @attention_causal +// CHECK: scf.for +// CHECK: scf.for +// CHECK: iree_linalg_ext.attention +// CHECK: ^bb0(%[[SCORE:.+]]: f32): +// CHECK: %[[M_RAW:.+]] = iree_linalg_ext.index 1 : index +// CHECK: %[[M_OFF:.+]] = affine.apply +// CHECK: %[[K2:.+]] = iree_linalg_ext.index 3 : index +// CHECK: arith.cmpi ugt, %[[K2]], %[[M_OFF]] : index +// CHECK: arith.select +// CHECK: iree_linalg_ext.yield + +// ----- + func.func @attention_float_mask(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>, %mask: tensor<192x1024x1024xf32>) -> tensor<192x1024x64xf32> { %0 = tensor.empty() : tensor<192x1024x64xf32> %scale = arith.constant 1.0 : f32