Skip to content

Commit d493934

Browse files
committed
Add with_attn_mask flag to conditionally apply attention mask
1 parent e4009f1 commit d493934

File tree

2 files changed

+29
-9
lines changed

2 files changed

+29
-9
lines changed

src/ntops/kernels/scaled_dot_product_attention.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def arrangement(
1919
attn_mask,
2020
scale,
2121
output,
22+
with_attn_mask,
2223
with_kv_cache,
2324
BLOCK_SIZE_M=BLOCK_SIZE_M,
2425
BLOCK_SIZE_N=BLOCK_SIZE_N,
@@ -68,6 +69,7 @@ def arrange_attn_mask(input):
6869
attn_mask_arranged = arrange_attn_mask(attn_mask)
6970
scale_arranged = scale
7071
output_arranged = arrange_query_or_output(output)
72+
with_attn_mask_arranged = with_attn_mask
7173

7274
if with_kv_cache:
7375
return (
@@ -81,6 +83,7 @@ def arrange_attn_mask(input):
8183
attn_mask_arranged,
8284
scale_arranged,
8385
output_arranged,
86+
with_attn_mask_arranged,
8487
)
8588

8689
return (
@@ -90,6 +93,7 @@ def arrange_attn_mask(input):
9093
attn_mask_arranged,
9194
scale_arranged,
9295
output_arranged,
96+
with_attn_mask_arranged,
9397
)
9498

9599

@@ -104,14 +108,19 @@ def application_with_kv_cache(
104108
attn_mask,
105109
scale,
106110
output,
111+
with_attn_mask,
107112
):
108113
present_key_slot = present_key # noqa: F841
109114
present_value_slot = present_value # noqa: F841
110115

111-
application_without_kv_cache(query, key, value, attn_mask, scale, output)
116+
application_without_kv_cache(
117+
query, key, value, attn_mask, scale, output, with_attn_mask
118+
)
112119

113120

114-
def application_without_kv_cache(query, key, value, attn_mask, scale, output):
121+
def application_without_kv_cache(
122+
query, key, value, attn_mask, scale, output, with_attn_mask
123+
):
115124
for i in range(query.shape[0]):
116125
query_i = (1.4426950408889634 * scale * query[i]).to(query[i].dtype)
117126

@@ -120,9 +129,12 @@ def application_without_kv_cache(query, key, value, attn_mask, scale, output):
120129
max = ntl.full((query_i.shape[-2],), float("-inf"), dtype=ntl.float32)
121130

122131
for j in range(key.shape[0]):
123-
qk = ntl.dot(query_i, ntl.trans(key[j])) + attn_mask[j]
132+
qk = ntl.dot(query_i, ntl.trans(key[j]))
124133
qk = ntl.where(key[j].offsets(-2) < key.source.shape[-2], qk, float("-inf"))
125134

135+
if with_attn_mask:
136+
qk += attn_mask[j]
137+
126138
next_max = ntl.maximum(max, ntl.max(qk, 1))
127139
stable_qk = ntl.exp2(qk - next_max[:, None])
128140

@@ -156,6 +168,7 @@ def make(with_kv_cache):
156168
for _ in range(4)
157169
)
158170
scale = Tensor(0)
171+
with_attn_mask = Tensor(0, constexpr=True)
159172

160173
if with_kv_cache:
161174
application = application_with_kv_cache
@@ -173,6 +186,7 @@ def make(with_kv_cache):
173186
attn_mask,
174187
scale,
175188
output,
189+
with_attn_mask,
176190
)
177191

178192
return ninetoothed.make(

src/ntops/torch.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -342,12 +342,17 @@ def scaled_dot_product_attention(
342342

343343
mask_shape = query.shape[:-1] + (key.shape[-2],)
344344

345-
if attn_mask is None:
346-
attn_mask = torch.zeros(mask_shape, dtype=query.dtype, device=query.device)
347-
elif attn_mask.dtype == torch.bool:
348-
attn_mask = torch.where(attn_mask, 0, float("-inf"))
345+
if attn_mask is not None:
346+
with_attn_mask = True
349347

350-
attn_mask = attn_mask.expand(mask_shape)
348+
if attn_mask.dtype == torch.bool:
349+
attn_mask = torch.where(attn_mask, 0, float("-inf"))
350+
351+
attn_mask = attn_mask.expand(mask_shape)
352+
else:
353+
with_attn_mask = False
354+
355+
attn_mask = torch.empty(mask_shape, device="meta")
351356

352357
if scale is None:
353358
scale = 1 / math.sqrt(query.shape[-1])
@@ -373,9 +378,10 @@ def scaled_dot_product_attention(
373378
attn_mask,
374379
scale,
375380
output,
381+
with_attn_mask,
376382
)
377383
else:
378-
kernel(query, key, value, attn_mask, scale, output)
384+
kernel(query, key, value, attn_mask, scale, output, with_attn_mask)
379385

380386
return output
381387

0 commit comments

Comments (0)