
Commit 31dc0af

allow for more than one learned head attn sink
1 parent 9e4f870 commit 31dc0af

File tree

4 files changed, +13 -12 lines

pyproject.toml
tests/test_x_transformers.py
x_transformers/attend.py
x_transformers/x_transformers.py

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.6.4"
+version = "2.6.5"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_x_transformers.py

Lines changed: 1 addition & 1 deletion
@@ -1245,7 +1245,7 @@ def test_learned_head_attn_sink():
         dim = 512,
         depth = 12,
         heads = 8,
-        attn_head_learned_sink = True
+        attn_head_learned_sinks = 4
     )
 )
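The test now requests four sinks per head instead of a single boolean flag. For context, a minimal sketch of such a configuration, assuming the usual TransformerWrapper around the Decoder; num_tokens and max_seq_len do not appear in the diff and are illustrative:

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,     # illustrative vocab size, not from the diff
    max_seq_len = 1024,     # illustrative context length, not from the diff
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8,
        attn_head_learned_sinks = 4    # was attn_head_learned_sink = True (a single sink)
    )
)

x = torch.randint(0, 20000, (1, 256))
logits = model(x)    # (1, 256, 20000)

The attn_ prefix on the Decoder kwarg routes head_learned_sinks through to the underlying attention module.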

x_transformers/attend.py

Lines changed: 9 additions & 8 deletions
@@ -176,7 +176,7 @@ def __init__(
         softclamp_logits = False,
         logit_softclamp_value = 50.,
         add_zero_kv = False,
-        head_learned_sink = False,
+        head_learned_sinks = 0,
         selective = False,
         hard = False,
         cope = None,
@@ -257,10 +257,10 @@ def __init__(

         # learned sink concatted pre-softmax, working solution from gpt-oss

-        assert not (head_learned_sink and flash), f'not supported for flash attention yet'
+        self.has_head_learned_sinks = head_learned_sinks > 0
+        assert not (self.has_head_learned_sinks and flash), f'not supported for flash attention yet'

-        self.head_learned_sink = head_learned_sink
-        self.head_attn_sink = Parameter(torch.zeros(heads)) if head_learned_sink else None
+        self.head_attn_sinks = Parameter(torch.zeros(heads, head_learned_sinks)) if self.has_head_learned_sinks else None

         # soft clamp attention logit value

@@ -517,9 +517,10 @@ def forward(
         if self.selective:
             sim = selective_attn(sim)

-        if self.head_learned_sink:
+        if self.has_head_learned_sinks:
             # add learned attention sink
-            attn_sink = repeat(self.head_attn_sink, 'h -> b h i 1', b = sim.shape[0], i = sim.shape[2])
+            num_sinks = self.head_attn_sinks.shape[-1]
+            attn_sink = repeat(self.head_attn_sinks, 'h sinks -> b h i sinks', b = sim.shape[0], i = sim.shape[2])
             sim = cat((attn_sink, sim), dim = -1)

         pre_softmax_attn = sim
@@ -530,9 +531,9 @@ def forward(

         post_softmax_attn = attn

-        if self.head_learned_sink:
+        if self.has_head_learned_sinks:
             # remove attention sink
-            attn = attn[..., 1:]
+            attn = attn[..., num_sinks:]

         attn = self.attn_dropout(attn)

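The single learned sink logit per head becomes a (heads, head_learned_sinks) parameter. The mechanism itself is unchanged: the learned logits are concatenated in front of the attention scores before the softmax, and the corresponding columns are dropped afterwards, so the sinks only absorb probability mass and never contribute to the output. A standalone sketch of that mechanic, with names taken from the diff where possible and illustrative shapes:

import torch
from torch import nn
from einops import repeat

heads, num_sinks = 8, 4
head_attn_sinks = nn.Parameter(torch.zeros(heads, num_sinks))   # one learned logit per (head, sink)

b, h, i, j = 2, heads, 16, 16
sim = torch.randn(b, h, i, j)                                    # pre-softmax attention logits

# add learned attention sinks as extra key positions, broadcast over batch and query length
attn_sink = repeat(head_attn_sinks, 'h sinks -> b h i sinks', b = b, i = i)
sim = torch.cat((attn_sink, sim), dim = -1)                      # (b, h, i, num_sinks + j)

attn = sim.softmax(dim = -1)

# remove the sink columns post-softmax; their probability mass is simply discarded,
# so the weights over the real keys no longer sum to one
attn = attn[..., num_sinks:]                                     # (b, h, i, j)

Because the slice uses num_sinks rather than a hard-coded 1, the same code path reproduces the previous single-sink behaviour when head_learned_sinks = 1.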

x_transformers/x_transformers.py

Lines changed: 2 additions & 2 deletions
@@ -1319,7 +1319,7 @@ def __init__(
         value_dim_head = None,
         dim_out = None,
         add_zero_kv = False, # same as add_zero_attn in pytorch
-        head_learned_sink = False,
+        head_learned_sinks = 0,
         rotate_num_heads = None,
         data_dependent_alibi = False,
         data_dependent_alibi_per_row = False,
@@ -1516,7 +1516,7 @@ def __init__(
             selective = selective,
             custom_attn_fn = custom_attn_fn,
             add_zero_kv = add_zero_kv,
-            head_learned_sink = head_learned_sink,
+            head_learned_sinks = head_learned_sinks,
             flash = flash,
             softclamp_logits = softclamp_logits,
             logit_softclamp_value = logit_softclamp_value,
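At the Attention level the kwarg is renamed in the same way and forwarded to Attend unchanged. A quick sanity check of the resulting parameter, assuming the Attend instance is reachable as .attend as in the current source; purely illustrative:

from x_transformers.x_transformers import Attention

attn = Attention(dim = 512, heads = 8, head_learned_sinks = 4)

# one learned pre-softmax logit per (head, sink) pair
print(attn.attend.head_attn_sinks.shape)   # torch.Size([8, 4])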
