diff --git a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td
index 99db283516a5..19bcb1fe01e1 100644
--- a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td
+++ b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td
@@ -262,15 +262,24 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
     Attention is defined as matmul(softmax(matmul(Q, transpose(K))+M), V) and
     has shape BxMxN. Usually, this operator also performs scaling, masking and
     dropout, but we leave that out of the current implementation.
+
+    When `is_causal` is true, the attention mask operand is a materialized
+    causal (lower-triangular) mask. Downstream consumers may use this flag
+    to replace the mask with a fused index computation.
   }];
 
   let arguments = (ins Variadic<AnyShaped>:$inputs,
-                       Variadic<AnyShaped>:$outputs
+                       Variadic<AnyShaped>:$outputs,
+                       OptionalAttr<BoolAttr>:$is_causal
   );
 
   let builders = [
-    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs), [{
-      build($_builder, $_state, TypeRange(outputs), inputs, outputs);
+    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+        CArg<"std::optional<bool>", "std::nullopt">:$isCausal), [{
+      build($_builder, $_state, TypeRange(outputs), inputs, outputs,
+            isCausal.has_value()
+                ? $_builder.getBoolAttr(*isCausal)
+                : BoolAttr());
     }]>
   ];
 
diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
index 84ff70a8ab53..fa797005683d 100644
--- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
+++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -2039,9 +2039,12 @@ class ConvertAtenScaledDotProductAttentionOp
     }
 
     // Overwrite with tm_tensor::attention
-    Value attention = AttentionOp::create(rewriter, loc, outType, inputs,
-                                          SmallVector<Value>{output})
-                          .getResult()[0];
+    std::optional<bool> isCausalOpt =
+        causal ? std::optional<bool>(true) : std::nullopt;
+    Value attention =
+        AttentionOp::create(rewriter, loc, inputs, SmallVector<Value>{output},
+                            isCausalOpt)
+            .getResult()[0];
     if (opTy != outType) {
       attention = tensor::ExpandShapeOp::create(rewriter, loc, opTy, attention,
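
For context, a minimal sketch of how a downstream consumer might read the new attribute. The getter name assumes ODS's usual generated-accessor convention for `OptionalAttr<BoolAttr>` (a `std::optional<bool> getIsCausal()` alongside `getIsCausalAttr()`); `isFusableCausalAttention` itself is a hypothetical helper, not part of this patch.

```cpp
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"

using namespace mlir;
using namespace mlir::torch::TMTensor;

// Hypothetical downstream hook: detect attention ops whose mask operand is
// known to be a materialized causal mask, so a later rewrite can drop the
// mask in favor of an index comparison fused into the attention body.
static bool isFusableCausalAttention(AttentionOp op) {
  // OptionalAttr<BoolAttr> yields a std::optional<bool> getter; the attribute
  // is simply absent when the producer did not set it, so value_or(false)
  // preserves the existing mask-as-data behavior for old IR.
  return op.getIsCausal().value_or(false);
}
```

Note the design choice in the builder above: passing a null `BoolAttr()` when `isCausal` is `std::nullopt` leaves the optional attribute unset, and the `CArg` default keeps existing two-argument call sites source-compatible.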