
Commit fc1db87

complete the qk clip on transformer wrapper / attention layers for muon training
1 parent 349f030 commit fc1db87

File tree

pyproject.toml
tests/test_x_transformers.py
x_transformers/x_transformers.py

3 files changed: +41 −2 lines


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.7.4"
+version = "2.7.6"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_x_transformers.py

Lines changed: 16 additions & 1 deletion
@@ -1315,7 +1315,7 @@ def test_simple_mdlm(
     loss = nar(seq)
     loss.loss.backward()
 
-def test_qk_clip():
+def test_qk_clip_attn():
     from x_transformers import Attention
 
     x = torch.randn(1, 1024, 512)
@@ -1325,3 +1325,18 @@ def test_qk_clip():
     out, intermediates = attn(x, return_intermediates = True)
 
     attn.qk_clip_(intermediates, tau = 100)
+
+def test_qk_clip_attn_layers():
+    from x_transformers import TransformerWrapper, Decoder
+
+    model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        attn_layers = Decoder(dim = 512, depth = 2)
+    )
+
+    seq = torch.randint(0, 256, (1, 1024))
+
+    out, intermediates = model(seq, return_intermediates = True)
+
+    model.attn_qk_clip_(intermediates)
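
The new test_qk_clip_attn_layers above exercises the hook in isolation. Below is a minimal sketch, not part of this commit, of how the same call might slot into a training step. The AdamW optimizer, learning rate, batch shape, and loss are stand-ins for illustration, since no Muon implementation ships in this diff; the relevant pattern is a forward pass with return_intermediates = True, backward and optimizer step, then the in-place clip.

import torch
import torch.nn.functional as F

from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2)
)

# stand-in optimizer for this sketch; a Muon optimizer would be used in practice
optimizer = torch.optim.AdamW(model.parameters(), lr = 3e-4)

seq = torch.randint(0, 256, (2, 1024))

# keep the intermediates so the attention logits are available for clipping
logits, intermediates = model(seq, return_intermediates = True)

# next-token prediction loss
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, logits.shape[-1]),
    seq[:, 1:].reshape(-1)
)

loss.backward()
optimizer.step()
optimizer.zero_grad()

# after the weight update, rescale query / key projections in place where needed
model.attn_qk_clip_(intermediates, tau = 100.)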

x_transformers/x_transformers.py

Lines changed: 24 additions & 0 deletions
@@ -2462,6 +2462,23 @@ def __init__(
 
         self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])
 
+    def attn_qk_clip_(
+        self,
+        intermediates: LayerIntermediates,
+        tau = 100.
+    ):
+        # pairs up the attention intermediates with each attention module and does qk clip proposed by kimi team
+
+        layer_and_layer_types = (self.layers, self.layer_types)
+
+        attn_layers = [layer for (_, layer, _), layer_type in zip(self.layers, self.layer_types) if layer_type in ('a', 'c')]
+        attn_intermeds = intermediates.attn_intermediates
+
+        assert len(attn_layers) == len(attn_intermeds)
+
+        for attn_layer, attn_inter in zip(attn_layers, attn_intermeds):
+            attn_layer.qk_clip_(attn_inter, tau = tau)
+
     def forward(
         self,
         x,
@@ -3192,6 +3209,13 @@ def init_(self):
         if not isinstance(self.pos_emb, always):
             nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
 
+    def attn_qk_clip_(
+        self,
+        intermediates: LayerIntermediates,
+        tau = 100.
+    ):
+        self.attn_layers.attn_qk_clip_(intermediates, tau = tau)
+
     def forward(
         self,
         x,
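
For context, Attention.qk_clip_ itself was added in an earlier commit, so its body does not appear in this diff; what attn_qk_clip_ ultimately delegates to, per attention module, is the QK-Clip rule credited to the Kimi team in the code comment above. The following is a standalone sketch of that rule as it is generally described, not the library's implementation, using hypothetical per-head weight tensors w_q and w_k.

import torch

@torch.no_grad()
def qk_clip_sketch(w_q, w_k, s_max, tau = 100.):
    # w_q, w_k: hypothetical query / key projection weights for one attention head
    # s_max: maximum pre-softmax attention logit observed for that head

    if s_max <= tau:
        return

    # scale both projections by sqrt(tau / s_max) so recomputed logits shrink by
    # a factor of tau / s_max, bringing them back under the threshold
    scale = (tau / s_max) ** 0.5
    w_q.mul_(scale)
    w_k.mul_(scale)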
