Fix backward pass on causal
Summary:
For the Triton kernel, only the causal case is implemented in the backward pass (see details in D66574768).

FA3 backward implements both causal and non-causal correctly.

Reviewed By: manman-ren, sijiac

Differential Revision: D66574710

fbshipit-source-id: 5117ea6c25066ca91aeed241e0e7a0dd2ae51b5d
xuzhao9 authored and facebook-github-bot committed Dec 2, 2024
1 parent deb9183 commit c509d84
Showing 2 changed files with 11 additions and 1 deletion.
2 changes: 2 additions & 0 deletions tritonbench/kernels/triton_fused_attention.py
@@ -1898,6 +1898,8 @@ def grid_tma(META):

     @staticmethod
     def backward(ctx, do):
+        if not ctx.causal:
+            raise NotImplementedError("only causal backward is implemented on Triton")
         q, k, v, o, M = ctx.saved_tensors
         assert do.is_contiguous()
         assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
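The guard above follows the usual torch.autograd.Function pattern of failing fast in backward() for unsupported configurations. Below is a minimal sketch of that pattern; the class name, forward signature, and placeholder returns are illustrative only, and the real kernel launches and tensor bookkeeping live in triton_fused_attention.py.

import torch

class AttentionSketch(torch.autograd.Function):
    """Illustrative stand-in for the fused-attention autograd Function."""

    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale):
        # The real implementation launches the Triton forward kernel;
        # here we only record what backward needs.
        ctx.causal = causal
        ctx.save_for_backward(q, k, v)
        return torch.empty_like(q)  # placeholder output

    @staticmethod
    def backward(ctx, do):
        # Same guard as the change above: the Triton backward kernel only
        # covers the causal case, so reject anything else up front.
        if not ctx.causal:
            raise NotImplementedError("only causal backward is implemented on Triton")
        q, k, v = ctx.saved_tensors
        # The real implementation launches the Triton backward kernel here.
        dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
        return dq, dk, dv, None, None

With this guard in place, a non-causal backward call fails with a clear NotImplementedError instead of entering a kernel path that was never implemented.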
10 changes: 9 additions & 1 deletion tritonbench/operators/flash_attention/operator.py
@@ -144,7 +144,11 @@ def parse_op_args(args: List[str]):
     parser.add_argument("--seq-len", type=int, default=11, help="Batch size")
     parser.add_argument("--n-heads", type=int, default=48, help="Number of heads")
     parser.add_argument("--d-head", type=int, default=64, help="specify head dimension")
-    parser.add_argument("--causal", action="store_true", help="enable causal")
+    parser.add_argument(
+        "--causal",
+        action="store_true",
+        help="enable causal (always true on backward)",
+    )
     return parser.parse_args(args)


@@ -164,6 +168,10 @@ def __init__(
         self.D_HEAD = args.d_head
         self.N_CTX = None
         self.causal = args.causal
+        # We always turn on causal for backward
+        # Because Triton-Flash-V2 does not support backward with non-causal
+        if self.mode == BenchmarkMode.BWD or self.mode == BenchmarkMode.FWD_BWD:
+            self.causal = True
         self.sm_scale = 1.3

     @register_benchmark()
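On the operator side, the net effect is that any benchmark mode that measures backward runs the causal variant regardless of what --causal was passed. The following sketch assumes only what the diff shows, namely a BenchmarkMode enum with FWD, BWD, and FWD_BWD members; the helper function is made up for illustration.

from enum import Enum

class BenchmarkMode(Enum):
    # Stand-in for tritonbench's BenchmarkMode as referenced in the diff.
    FWD = "fwd"
    BWD = "bwd"
    FWD_BWD = "fwd_bwd"

def effective_causal(requested_causal: bool, mode: BenchmarkMode) -> bool:
    # Mirrors the __init__ logic above: backward-measuring modes force causal=True
    # because the Triton-Flash-V2 backward kernel has no non-causal path.
    if mode in (BenchmarkMode.BWD, BenchmarkMode.FWD_BWD):
        return True
    return requested_causal

assert effective_causal(False, BenchmarkMode.FWD) is False
assert effective_causal(False, BenchmarkMode.BWD) is True   # forced on
assert effective_causal(True, BenchmarkMode.FWD_BWD) is True

In practice this means a backward run of the flash_attention operator always benchmarks the causal configuration, matching the updated --causal help text ("always true on backward").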

