Merge pull request ReaLLMASIC#137 from gkielian/add_parallel_mlp_attn_option

Add parallel mlp attn option
klei22 authored Apr 10, 2024
2 parents 71c3df9 + 6b4c05c commit 948a8bd
Showing 3 changed files with 43 additions and 10 deletions.
20 changes: 20 additions & 0 deletions explorations/mlp_par_sweep.json
@@ -0,0 +1,20 @@
+[
+  {
+    "max_iters": ["3500"],
+    "n_layer": ["6"],
+    "n_head": ["6"],
+    "n_embd": ["384"],
+    "block_size":["256"],
+    "use_post_ln": [true, false],
+    "use_parallel_mlp": [true, false],
+    "device": ["cuda"],
+    "dtype": ["bfloat16"],
+    "dataset": ["shakespeare_char"],
+    "use_rotary_embeddings": [false],
+    "use_abs_pos_embeddings": [true],
+    "compile": [true],
+    "softmax_variant_attn": ["softmax", "polymax", "saturatingconsmax"],
+    "tensorboard_run_name": ["mlp_parallelization"]
+  }
+]
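
Assuming the sweep runner in explorations/ expands each list in this JSON as a Cartesian product of values (an assumption; the runner script is not part of this diff), the file above describes 2 × 2 × 3 = 12 training runs, covering every combination of use_post_ln, use_parallel_mlp, and softmax_variant_attn. A minimal sketch of that expansion:

    # Illustrative only: expands the sweep file into individual run configs,
    # assuming a Cartesian-product interpretation of the value lists above.
    import itertools
    import json

    with open("explorations/mlp_par_sweep.json") as f:
        sweep = json.load(f)[0]      # this file contains a single sweep group

    keys = list(sweep.keys())
    runs = [dict(zip(keys, combo)) for combo in itertools.product(*sweep.values())]
    print(len(runs))                 # 12 = 2 (post-ln) * 2 (parallel mlp) * 3 (softmax variant)
    print(runs[0]["use_parallel_mlp"], runs[0]["softmax_variant_attn"])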

32 changes: 22 additions & 10 deletions model.py
@@ -268,33 +268,43 @@ def __init__(self, config, mlp=None, attn=None):

         if config.layernorm_variant == 'rmsnorm':
             self.ln_1 = RMSNorm(config.n_embd)
-            self.ln_2 = RMSNorm(config.n_embd)
+            if not config.use_parallel_mlp:
+                self.ln_2 = RMSNorm(config.n_embd)

         if config.layernorm_variant == 'layernorm':
             self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
-            self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
+            if not config.use_parallel_mlp:
+                self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)

         self.use_post_ln = config.use_post_ln
+        self.use_parallel_mlp = config.use_parallel_mlp

         # Allow for sharing attn between blocks
         if attn == None:
-            self.attn = CausalSelfAttention(config)
+            self.attn = CausalSelfAttention(config)
         else:
-            self.attn = attn
+            self.attn = attn

         # Allow for sharing mlp between blocks
         if mlp == None:
-            self.mlp = MLP(config)
+            self.mlp = MLP(config)
         else:
-            self.mlp = mlp
+            self.mlp = mlp

     def forward(self, x):
         if self.use_post_ln:
-            x = self.ln_1(x + self.attn(x))
-            x = self.ln_2(x + self.mlp(x))
+            if self.use_parallel_mlp:
+                x = self.ln_1(x + self.attn(x) + self.mlp(x))
+            else:
+                x = self.ln_1(x + self.attn(x))
+                x = self.ln_2(x + self.mlp(x))
         else:
-            x = x + self.attn(self.ln_1(x))
-            x = x + self.mlp(self.ln_2(x))
+            if self.use_parallel_mlp:
+                ln_1 = self.ln_1(x)
+                x = x + self.attn(ln_1) + self.mlp(ln_1)
+            else:
+                x = x + self.attn(self.ln_1(x))
+                x = x + self.mlp(self.ln_2(x))
         return x

 @dataclass
@@ -309,6 +319,8 @@ class GPTConfig:
     window_size: int = 128
     gate: bool = False

+    use_parallel_mlp: bool = False
+
     # Shared parameters
     # MLP
     shared_mlp_size: int = 1
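
For reference, the new option changes the residual layout of the transformer block: with use_parallel_mlp enabled, a single normalization feeds both the attention and the MLP branch, and both outputs are added to the same residual stream (in the post-LN path the sum is normalized afterwards), instead of the default attention-then-MLP sequence with two norms. This is the parallel block used in GPT-J/PaLM-style models. Below is a minimal standalone sketch of the two pre-LN layouts; the attention and MLP here are simplified stand-ins, not the repository's CausalSelfAttention and MLP classes.

    # Illustrative pre-LN block showing sequential vs. parallel attention/MLP.
    import torch
    import torch.nn as nn

    class BlockSketch(nn.Module):
        def __init__(self, n_embd, use_parallel_mlp=True):
            super().__init__()
            self.use_parallel_mlp = use_parallel_mlp
            self.ln_1 = nn.LayerNorm(n_embd)
            if not use_parallel_mlp:
                self.ln_2 = nn.LayerNorm(n_embd)
            self.attn = nn.Linear(n_embd, n_embd)   # stand-in for causal self-attention
            self.mlp = nn.Sequential(nn.Linear(n_embd, 4 * n_embd), nn.GELU(),
                                     nn.Linear(4 * n_embd, n_embd))

        def forward(self, x):
            if self.use_parallel_mlp:
                # One norm feeds both branches; both outputs join the same residual.
                h = self.ln_1(x)
                return x + self.attn(h) + self.mlp(h)
            # Default sequential layout: attention first, then MLP, each with its own norm.
            x = x + self.attn(self.ln_1(x))
            return x + self.mlp(self.ln_2(x))

    x = torch.randn(2, 16, 64)                                # (batch, tokens, n_embd)
    print(BlockSketch(64, use_parallel_mlp=True)(x).shape)    # torch.Size([2, 16, 64])
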
1 change: 1 addition & 0 deletions train.py
@@ -56,6 +56,7 @@ def parse_args():
     model_group.add_argument('--n_embd', default=384, type=int)
     model_group.add_argument('--dropout', default=0.2, type=float)
     model_group.add_argument('--use_post_ln', default=False, action=argparse.BooleanOptionalAction)
+    model_group.add_argument('--use_parallel_mlp', default=False, action=argparse.BooleanOptionalAction)
     model_group.add_argument('--window_size', default=None, type=int, help="Sliding window size, note this cannot be greater than block size")
     model_group.add_argument('--gate', default=False, action=argparse.BooleanOptionalAction, help="option for gated attention see https://arxiv.org/abs/2306.12929")
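
Because the new flag is registered with argparse.BooleanOptionalAction, the command line accepts both --use_parallel_mlp and --no-use_parallel_mlp, matching the true/false values swept in explorations/mlp_par_sweep.json; with the default of False the block keeps its original sequential attention-then-MLP layout.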

