Skip to content

Commit 57119f0

Browse files
committed
Add support for is_causal
1 parent d493934 commit 57119f0

File tree

3 files changed

+30
-9
lines changed

3 files changed

+30
-9
lines changed

src/ntops/kernels/scaled_dot_product_attention.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def arrangement(
1717
present_key_slot,
1818
present_value_slot,
1919
attn_mask,
20+
is_causal,
2021
scale,
2122
output,
2223
with_attn_mask,
@@ -67,6 +68,7 @@ def arrange_attn_mask(input):
6768
present_value_slot
6869
)
6970
attn_mask_arranged = arrange_attn_mask(attn_mask)
71+
is_causal_arranged = is_causal
7072
scale_arranged = scale
7173
output_arranged = arrange_query_or_output(output)
7274
with_attn_mask_arranged = with_attn_mask
@@ -81,6 +83,7 @@ def arrange_attn_mask(input):
8183
present_key_slot_arranged,
8284
present_value_slot_arranged,
8385
attn_mask_arranged,
86+
is_causal_arranged,
8487
scale_arranged,
8588
output_arranged,
8689
with_attn_mask_arranged,
@@ -91,6 +94,7 @@ def arrange_attn_mask(input):
9194
key_arranged,
9295
value_arranged,
9396
attn_mask_arranged,
97+
is_causal_arranged,
9498
scale_arranged,
9599
output_arranged,
96100
with_attn_mask_arranged,
@@ -106,6 +110,7 @@ def application_with_kv_cache(
106110
present_key_slot,
107111
present_value_slot,
108112
attn_mask,
113+
is_causal,
109114
scale,
110115
output,
111116
with_attn_mask,
@@ -114,12 +119,12 @@ def application_with_kv_cache(
114119
present_value_slot = present_value # noqa: F841
115120

116121
application_without_kv_cache(
117-
query, key, value, attn_mask, scale, output, with_attn_mask
122+
query, key, value, attn_mask, is_causal, scale, output, with_attn_mask
118123
)
119124

120125

121126
def application_without_kv_cache(
122-
query, key, value, attn_mask, scale, output, with_attn_mask
127+
query, key, value, attn_mask, is_causal, scale, output, with_attn_mask
123128
):
124129
for i in range(query.shape[0]):
125130
query_i = (1.4426950408889634 * scale * query[i]).to(query[i].dtype)
@@ -135,6 +140,10 @@ def application_without_kv_cache(
135140
if with_attn_mask:
136141
qk += attn_mask[j]
137142

143+
if is_causal:
144+
mask = query[i].offsets(-2)[:, None] >= key[j].offsets(-2)[None, :]
145+
qk = ntl.where(mask, qk, float("-inf"))
146+
138147
next_max = ntl.maximum(max, ntl.max(qk, 1))
139148
stable_qk = ntl.exp2(qk - next_max[:, None])
140149

@@ -168,7 +177,7 @@ def make(with_kv_cache):
168177
for _ in range(4)
169178
)
170179
scale = Tensor(0)
171-
with_attn_mask = Tensor(0, constexpr=True)
180+
is_causal, with_attn_mask = (Tensor(0, constexpr=True) for _ in range(2))
172181

173182
if with_kv_cache:
174183
application = application_with_kv_cache
@@ -184,6 +193,7 @@ def make(with_kv_cache):
184193
present_key_slot,
185194
present_value_slot,
186195
attn_mask,
196+
is_causal,
187197
scale,
188198
output,
189199
with_attn_mask,

src/ntops/torch.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -336,10 +336,12 @@ def scaled_dot_product_attention(
336336
):
337337
# TODO: Support `dropout_p`.
338338
assert dropout_p == 0, "`dropout_p` is not supported yet."
339-
# TODO: Support `is_causal`.
340-
assert not is_causal, "`is_causal` is not supported yet."
341339
assert enable_gqa, "GQA must be enabled for now."
342340

341+
assert attn_mask is None or not is_causal, (
342+
"Cannot use `attn_mask` and `is_causal` together."
343+
)
344+
343345
mask_shape = query.shape[:-1] + (key.shape[-2],)
344346

345347
if attn_mask is not None:
@@ -376,12 +378,13 @@ def scaled_dot_product_attention(
376378
present_key_slot,
377379
present_value_slot,
378380
attn_mask,
381+
is_causal,
379382
scale,
380383
output,
381384
with_attn_mask,
382385
)
383386
else:
384-
kernel(query, key, value, attn_mask, scale, output, with_attn_mask)
387+
kernel(query, key, value, attn_mask, is_causal, scale, output, with_attn_mask)
385388

386389
return output
387390

tests/test_scaled_dot_product_attention.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,17 @@ def _generate_random_size():
1717
arguments = []
1818

1919
attn_mask_types = (None, torch.bool, torch.float32)
20+
is_causal_values = (False, True)
2021
scales = (None, random.uniform(0.05, 0.5))
2122
dtypes = (torch.float32, torch.float16)
2223
with_kv_cache_values = (False, True)
2324

24-
for attn_mask_type, scale, dtype, with_kv_cache in itertools.product(
25-
attn_mask_types, scales, dtypes, with_kv_cache_values
25+
for attn_mask_type, is_causal, scale, dtype, with_kv_cache in itertools.product(
26+
attn_mask_types, is_causal_values, scales, dtypes, with_kv_cache_values
2627
):
28+
if attn_mask_type is not None and is_causal:
29+
continue
30+
2731
batch_size = random.randint(1, 4)
2832
num_heads_q = 2 ** random.randint(1, 5)
2933
seq_len_q = _generate_random_size()
@@ -49,6 +53,7 @@ def _generate_random_size():
4953
num_heads_kv,
5054
seq_len_kv,
5155
attn_mask_type,
56+
is_causal,
5257
scale,
5358
enable_gqa,
5459
with_kv_cache,
@@ -59,7 +64,7 @@ def _generate_random_size():
5964
)
6065

6166
return (
62-
"batch_size, num_heads_q, seq_len_q, head_dim, num_heads_kv, seq_len_kv, attn_mask_type, scale, enable_gqa, with_kv_cache, dtype, atol, rtol",
67+
"batch_size, num_heads_q, seq_len_q, head_dim, num_heads_kv, seq_len_kv, attn_mask_type, is_causal, scale, enable_gqa, with_kv_cache, dtype, atol, rtol",
6368
arguments,
6469
)
6570

@@ -74,6 +79,7 @@ def test_cuda(
7479
num_heads_kv,
7580
seq_len_kv,
7681
attn_mask_type,
82+
is_causal,
7783
scale,
7884
enable_gqa,
7985
with_kv_cache,
@@ -129,6 +135,7 @@ def _generate_present_and_slot(tensor):
129135
key,
130136
value,
131137
attn_mask=attn_mask,
138+
is_causal=is_causal,
132139
scale=scale,
133140
enable_gqa=enable_gqa,
134141
present_key=present_key,
@@ -141,6 +148,7 @@ def _generate_present_and_slot(tensor):
141148
key_cloned,
142149
value_cloned,
143150
attn_mask=attn_mask,
151+
is_causal=is_causal,
144152
scale=scale,
145153
enable_gqa=enable_gqa,
146154
)

0 commit comments

Comments (0)