@@ -4,8 +4,8 @@
 from typing import Tuple, Callable
 
 import torch
-from torch.nn import Module
-from torch import nn, einsum, Tensor
+from torch.nn import Module, Parameter
+from torch import cat, nn, einsum, Tensor
 
 import torch.nn.functional as F
 
 from collections import namedtuple
@@ -176,6 +176,7 @@ def __init__(
         softclamp_logits = False,
         logit_softclamp_value = 50.,
         add_zero_kv = False,
+        head_learned_sink = False,
         selective = False,
         hard = False,
         cope = None,
@@ -254,6 +255,13 @@ def __init__(
 
         self.add_zero_kv = add_zero_kv
 
+        # learned sink concatted pre-softmax, working solution from gpt-oss
+
+        assert not (head_learned_sink and flash), 'not supported for flash attention yet'
+
+        self.head_learned_sink = head_learned_sink
+        self.head_attn_sink = Parameter(torch.zeros(heads)) if head_learned_sink else None
+
         # soft clamp attention logit value
 
         if softclamp_logits:
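
For readers who have not met attention sinks before: as the comment notes, the trick is borrowed from gpt-oss. Each head gets one extra learned logit that takes part in the softmax normalization but whose weight is discarded afterwards, so a head can shrink the total probability mass it spreads over real positions. The assert above reflects that a fused flash kernel computes its softmax internally, leaving no exposed logits matrix to concatenate a sink column onto (my reading; the diff only states it is unsupported). A minimal single-head sketch of the behavior, illustrative only and not this module's code:

import torch
import torch.nn.functional as F

# one row of attention logits for a single head, plus one learned sink logit
logits = torch.randn(6)                          # scores against 6 real positions
sink = torch.zeros(1, requires_grad = True)      # learned scalar, zero-initialized as in the diff

attn = F.softmax(torch.cat((sink, logits)), dim = -1)  # softmax over sink + real positions
attn = attn[1:]                                  # drop the sink weight before aggregating values

print(attn.sum())                                # < 1: the head can park mass on the sink
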
@@ -315,10 +323,10 @@ def flash_attn(
         if self.l2_distance:
             k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
             k = F.pad(k, (0, 1), value = -1.)
-            k = torch.cat((k, k_norm_sq), dim = -1)
+            k = cat((k, k_norm_sq), dim = -1)
 
             q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
-            q = torch.cat((2 * q, q_norm_sq), dim = -1)
+            q = cat((2 * q, q_norm_sq), dim = -1)
             q = F.pad(q, (0, 1), value = -1.)
 
         # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
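
A side note on the l2_distance branch touched here: the padding and concatenation encode the identity dot(q_aug, k_aug) = 2 q·k - |q|^2 - |k|^2 = -|q - k|^2, which lets a plain dot-product (flash) kernel compute negative squared L2 distances. A quick standalone check of that identity, with made-up shapes:

import torch
import torch.nn.functional as F

q = torch.randn(2, 16, 8)                        # (batch, seq, dim_head), sizes are illustrative
k = torch.randn(2, 16, 8)

# same augmentation as in the hunk above
k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
k_aug = torch.cat((F.pad(k, (0, 1), value = -1.), k_norm_sq), dim = -1)

q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
q_aug = F.pad(torch.cat((2 * q, q_norm_sq), dim = -1), (0, 1), value = -1.)

sim = torch.einsum('b i d, b j d -> b i j', q_aug, k_aug)

assert torch.allclose(sim, -torch.cdist(q, k) ** 2, atol = 1e-4)
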
@@ -509,6 +517,11 @@ def forward(
         if self.selective:
             sim = selective_attn(sim)
 
+        if self.head_learned_sink:
+            # add learned attention sink
+            attn_sink = repeat(self.head_attn_sink, 'h -> b h i 1', b = sim.shape[0], i = sim.shape[2])
+            sim = cat((attn_sink, sim), dim = -1)
+
         pre_softmax_attn = sim
 
         attn = self.attn_fn(sim)
@@ -517,6 +530,10 @@ def forward(
 
         post_softmax_attn = attn
 
+        if self.head_learned_sink:
+            # remove attention sink
+            attn = attn[..., 1:]
+
         attn = self.attn_dropout(attn)
 
         if exists(self.post_softmax_talking_heads):
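
Putting the two forward hunks together: the per-head sink is broadcast to an extra leftmost column of logits, concatenated before self.attn_fn (the softmax), and that column is sliced off once the weights are normalized. A hedged end-to-end sketch with made-up sizes, omitting the library's masking, dropout, and talking-heads steps (repeat is presumably already imported from einops elsewhere in the file, since the hunk does not add it):

import torch
from torch.nn import Parameter
from einops import repeat

b, h, i, j, d = 2, 4, 16, 16, 32                 # illustrative sizes
head_attn_sink = Parameter(torch.zeros(h))       # one learned sink logit per head

q = torch.randn(b, h, i, d)
k = torch.randn(b, h, j, d)
v = torch.randn(b, h, j, d)

sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * d ** -0.5

# add learned attention sink as an extra leftmost column of the logits
attn_sink = repeat(head_attn_sink, 'h -> b h i 1', b = b, i = i)
sim = torch.cat((attn_sink, sim), dim = -1)      # (b, h, i, j + 1)

attn = sim.softmax(dim = -1)
attn = attn[..., 1:]                             # remove the sink column after normalizing

out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)

print(attn.sum(dim = -1).max())                  # strictly < 1: mass parked on the sink is discarded
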