
Commit 72e6107

complete the qk clip on transformer wrapper / attention layers for muon training
1 parent 349f030

File tree

3 files changed: +38 −2 lines

pyproject.toml
tests/test_x_transformers.py
x_transformers/x_transformers.py

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.7.4"
+version = "2.7.5"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_x_transformers.py

Lines changed: 16 additions & 1 deletion

@@ -1315,7 +1315,7 @@ def test_simple_mdlm(
     loss = nar(seq)
     loss.loss.backward()

-def test_qk_clip():
+def test_qk_clip_attn():
     from x_transformers import Attention

     x = torch.randn(1, 1024, 512)
@@ -1325,3 +1325,18 @@ def test_qk_clip():
     out, intermediates = attn(x, return_intermediates = True)

     attn.qk_clip_(intermediates, tau = 100)
+
+def test_qk_clip_attn_layers():
+    from x_transformers import TransformerWrapper, Decoder
+
+    model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        attn_layers = Decoder(dim = 512, depth = 2)
+    )
+
+    seq = torch.randint(0, 256, (1, 1024))
+
+    out, intermediates = model(seq, return_intermediates = True)
+
+    model.attn_qk_clip_(intermediates)
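
For reference, the QK-Clip rule these tests exercise (proposed by the Kimi team as part of MuonClip training) rescales a head's query and key projection weights whenever the maximum pre-softmax attention logit seen during the forward pass exceeds tau, so that the worst-case logit is brought back down to roughly tau. The following is a rough per-head sketch of that rule only; it is written independently of Attention.qk_clip_ and is not code from this commit.

import torch

def qk_clip_sketch_(wq, wk, max_logit, tau = 100.):
    # wq, wk: per-head query / key projection weights, shape (heads, dim_head, dim)
    # max_logit: (heads,) maximum pre-softmax attention logit observed in the last forward pass
    # scale both projections by sqrt(tau / max_logit) for heads that exceeded tau,
    # so the q·k logit is multiplied by tau / max_logit overall
    gamma = (tau / max_logit.clamp(min = tau)).sqrt()  # equals 1. for heads that stayed under tau
    wq.mul_(gamma[:, None, None])
    wk.mul_(gamma[:, None, None])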

x_transformers/x_transformers.py

Lines changed: 21 additions & 0 deletions

@@ -2462,6 +2462,20 @@ def __init__(

         self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])

+    def attn_qk_clip_(
+        self,
+        intermediates: LayerIntermediates,
+        tau = 100.
+    ):
+        # pairs up the attention intermediates with each attention module and does qk clip proposed by kimi team
+
+        for (_, layer, _), layer_type, attn_inter in zip(self.layers, self.layer_types, intermediates.attn_intermediates):
+
+            if layer_type not in ('a', 'c'):
+                continue
+
+            layer.qk_clip_(attn_inter, tau = tau)
+
     def forward(
         self,
         x,
@@ -3192,6 +3206,13 @@ def init_(self):
         if not isinstance(self.pos_emb, always):
             nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)

+    def attn_qk_clip_(
+        self,
+        intermediates: LayerIntermediates,
+        tau = 100.
+    ):
+        self.attn_layers.attn_qk_clip_(intermediates, tau = tau)
+
     def forward(
         self,
         x,
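
Taken together, the intended call pattern is: run the wrapper with return_intermediates = True so the attention intermediates are captured, update the weights, then call model.attn_qk_clip_(intermediates, tau = ...) to rescale any heads whose logits blew past tau. Below is a minimal sketch of one such training step; the cross-entropy loss and the stand-in Adam optimizer are illustrative assumptions (the commit targets Muon-style training), while the TransformerWrapper construction and the two highlighted calls come from the diffs above.

import torch
import torch.nn.functional as F
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2)
)

# stand-in optimizer; in practice this would be a Muon(-Clip) style optimizer
optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)

seq = torch.randint(0, 256, (1, 1024))

# keep the attention intermediates so qk clip can later inspect the attention logits
logits, intermediates = model(seq[:, :-1], return_intermediates = True)

loss = F.cross_entropy(logits.transpose(1, 2), seq[:, 1:])
loss.backward()

optimizer.step()
optimizer.zero_grad()

# after the weight update, clip q/k projections for heads whose max logit exceeded tau
model.attn_qk_clip_(intermediates, tau = 100.)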
