From c509d8427c4918ea3a8a69022543b646974442c0 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Mon, 2 Dec 2024 11:20:28 -0800
Subject: [PATCH] Fix backward pass on causal

Summary:
For the Triton kernel, only causal is implemented on the backward pass (see
details at D66574768). FA3 backward implements both causal and non-causal
correctly.

Reviewed By: manman-ren, sijiac

Differential Revision: D66574710

fbshipit-source-id: 5117ea6c25066ca91aeed241e0e7a0dd2ae51b5d
---
 tritonbench/kernels/triton_fused_attention.py     |  2 ++
 tritonbench/operators/flash_attention/operator.py | 10 +++++++++-
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/tritonbench/kernels/triton_fused_attention.py b/tritonbench/kernels/triton_fused_attention.py
index 9348474b..4946812f 100644
--- a/tritonbench/kernels/triton_fused_attention.py
+++ b/tritonbench/kernels/triton_fused_attention.py
@@ -1898,6 +1898,8 @@ def grid_tma(META):
 
     @staticmethod
     def backward(ctx, do):
+        if not ctx.causal:
+            raise NotImplementedError("only causal backward is implemented on Triton")
         q, k, v, o, M = ctx.saved_tensors
         assert do.is_contiguous()
         assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py
index 539a4920..f057ad9f 100644
--- a/tritonbench/operators/flash_attention/operator.py
+++ b/tritonbench/operators/flash_attention/operator.py
@@ -144,7 +144,11 @@ def parse_op_args(args: List[str]):
     parser.add_argument("--seq-len", type=int, default=11, help="Batch size")
     parser.add_argument("--n-heads", type=int, default=48, help="Number of heads")
     parser.add_argument("--d-head", type=int, default=64, help="specify head dimension")
-    parser.add_argument("--causal", action="store_true", help="enable causal")
+    parser.add_argument(
+        "--causal",
+        action="store_true",
+        help="enable causal (always true on backward)",
+    )
     return parser.parse_args(args)
 
 
@@ -164,6 +168,10 @@ def __init__(
         self.D_HEAD = args.d_head
         self.N_CTX = None
         self.causal = args.causal
+        # We always turn on causal for backward
+        # Because Triton-Flash-V2 does not support backward with non-causal
+        if self.mode == BenchmarkMode.BWD or self.mode == BenchmarkMode.FWD_BWD:
+            self.causal = True
         self.sm_scale = 1.3
 
     @register_benchmark()
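
The two hunks work together: the operator side forces causal attention for any
benchmark mode that runs a backward pass, and the kernel side raises instead of
silently producing wrong gradients if a non-causal backward is requested anyway.
Below is a minimal standalone sketch of that interaction, separate from the
patch itself; BenchmarkMode is recreated here as a stand-in enum, and
resolve_causal / check_backward_supported are hypothetical helper names used
only for illustration, not functions from tritonbench.

# Standalone sketch of the two guards added by the patch (illustrative only).
from enum import Enum


class BenchmarkMode(Enum):
    # Stand-in for the tritonbench BenchmarkMode referenced in operator.py.
    FWD = "fwd"
    BWD = "bwd"
    FWD_BWD = "fwd_bwd"


def resolve_causal(requested_causal: bool, mode: BenchmarkMode) -> bool:
    # Mirrors the operator __init__ change: any mode that runs backward
    # overrides the --causal flag to True.
    if mode in (BenchmarkMode.BWD, BenchmarkMode.FWD_BWD):
        return True
    return requested_causal


def check_backward_supported(causal: bool) -> None:
    # Mirrors the kernel-side guard: fail loudly rather than run the
    # unimplemented non-causal backward path.
    if not causal:
        raise NotImplementedError("only causal backward is implemented on Triton")


if __name__ == "__main__":
    assert resolve_causal(False, BenchmarkMode.FWD) is False
    assert resolve_causal(False, BenchmarkMode.BWD) is True
    check_backward_supported(resolve_causal(False, BenchmarkMode.FWD_BWD))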