From b2417f229cf29d9e8664b20d9aa84cab9790f605 Mon Sep 17 00:00:00 2001 From: George Smyrnis Date: Thu, 21 Mar 2024 03:26:08 +0000 Subject: [PATCH 1/7] Geglu like griffin. --- open_lm/model.py | 4 ++-- open_lm/model_configs/open_lm_1b_geglu.json | 10 ++++++++++ open_lm/model_configs/open_lm_7b_geglu.json | 10 ++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) create mode 100644 open_lm/model_configs/open_lm_1b_geglu.json create mode 100644 open_lm/model_configs/open_lm_7b_geglu.json diff --git a/open_lm/model.py b/open_lm/model.py index 9f5c6397..540c4133 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -243,8 +243,8 @@ def __init__(self, layer_id, args: Params): self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False) self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="none"), self._ff_w2) elif args.ffn_type == "gemma_geglu": - # this follows llama / lit llama -- go to multiple of 256 - self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256) + # this is from the Griffin paper - hidden_dim is always 3 * dim + self.hidden_dim = 3 * args.dim self.feed_forward = GemmaMLP(args.dim, self.hidden_dim, layer_id) elif args.ffn_type == "moe": moe_args = MoEArgs( diff --git a/open_lm/model_configs/open_lm_1b_geglu.json b/open_lm/model_configs/open_lm_1b_geglu.json new file mode 100644 index 00000000..daa90523 --- /dev/null +++ b/open_lm/model_configs/open_lm_1b_geglu.json @@ -0,0 +1,10 @@ +{ + "hidden_dim": 2048, + "n_layers": 23, + "n_heads": 16, + "seq_len": 2048, + "vocab_size": 50432, + "post_embed_norm": false, + "weight_tying": false, + "ffn_type": "gemma_geglu" +} diff --git a/open_lm/model_configs/open_lm_7b_geglu.json b/open_lm/model_configs/open_lm_7b_geglu.json new file mode 100644 index 00000000..9763dccc --- /dev/null +++ b/open_lm/model_configs/open_lm_7b_geglu.json @@ -0,0 +1,10 @@ +{ + "hidden_dim": 4096, + "n_layers": 30, + "n_heads": 32, + "seq_len": 2048, + "vocab_size": 50432, + "post_embed_norm": false, + "weight_tying": false, + "ffn_type": "gemma_geglu" +} From c5b6e6576f9a29160fc974bc039215775d3f3c30 Mon Sep 17 00:00:00 2001 From: George Smyrnis Date: Mon, 25 Mar 2024 20:06:56 -0500 Subject: [PATCH 2/7] Revert to original sizes and add MQA. 
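For reference, the multi-query variant keeps n_heads query heads but a single
shared key head and a single shared value head, which is why in_proj grows to
(n_heads + 2) * head_dim instead of 3 * n_heads * head_dim. A minimal sketch of
the intended split and broadcast (a standalone illustration, not the exact code
in the diff; the expand back to n_heads exists only so the unchanged attention
kernels still see the usual (batch, seq, heads, head_dim) layout):

    import torch
    import torch.nn as nn

    def mqa_split(x: torch.Tensor, in_proj: nn.Linear, n_heads: int, head_dim: int):
        b, q_len, _ = x.shape
        qkv = in_proj(x)                                # (b, q_len, (n_heads + 2) * head_dim)
        queries = qkv[..., : -2 * head_dim]             # n_heads * head_dim
        keys = qkv[..., -2 * head_dim : -head_dim]      # one key head
        vals = qkv[..., -head_dim:]                     # one value head
        queries = queries.view(b, q_len, n_heads, head_dim)
        # Broadcast the single K/V head across all query heads.
        keys = keys.view(b, q_len, 1, head_dim).expand(-1, -1, n_heads, -1)
        vals = vals.view(b, q_len, 1, head_dim).expand(-1, -1, n_heads, -1)
        return queries, keys, vals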
--- open_lm/attention.py | 2 +- open_lm/model.py | 28 ++++++++++++++++++--- open_lm/model_configs/open_lm_1b_geglu.json | 2 +- open_lm/model_configs/open_lm_7b_geglu.json | 2 +- open_lm/params.py | 2 +- 5 files changed, 28 insertions(+), 8 deletions(-) diff --git a/open_lm/attention.py b/open_lm/attention.py index f134786c..ecde2fb1 100644 --- a/open_lm/attention.py +++ b/open_lm/attention.py @@ -164,7 +164,7 @@ def get_attn_func( ): if attn_name == "auto": return xformers_attn if torch.cuda.is_available() else torch_attn - elif attn_name == "xformers_attn": + elif attn_name == "xformers_attn" or attn_name == "xformers_mqa": return xformers_attn elif attn_name == "xformers_attn_variable_length": # Upon changing the input sequence length, xformers attention changes diff --git a/open_lm/model.py b/open_lm/model.py index 540c4133..c6a378cc 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -99,6 +99,7 @@ class Params: moe_freq: int = 0 positional_embedding_type: str = "rotary" ffn_type: str = "swiglu" + mqa: bool = False def get_pos_embed(args: Params): @@ -120,11 +121,16 @@ def __init__(self, layer_id, args: Params): super().__init__() self.n_heads = args.n_heads self.head_dim = args.dim // args.n_heads - self.in_proj = nn.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False) + self.mqa = args.mqa + if not self.mqa: + self.in_proj = nn.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False) + else: + self.in_proj = nn.Linear(args.dim, (args.n_heads + 2)* self.head_dim, bias=False) self.out_proj = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) self.pos_embed = get_pos_embed(args) self.attn_fn = args.attn_func self.apply_qk_norm = args.apply_qk_norm + # initialize norm layers for queries and keys if needed self.q_norm = ( @@ -158,14 +164,27 @@ def reset_parameters(self): def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cache=False, attention_mask=None): batchsize, q_len, _ = x.shape - queries, keys, vals = self.in_proj(x).chunk(3, dim=-1) + if not self.mqa: + queries, keys, vals = self.in_proj(x).chunk(3, dim=-1) + else: + print("HI") + qkv = self.in_proj(x) + queries = qkv[..., :-2 * self.head_dim] + keys = qkv[..., -2 * self.head_dim : - self.head_dim] + vals = qkv[..., - self.head_dim :] queries = self.q_norm(queries) keys = self.k_norm(keys) queries = queries.view(batchsize, q_len, self.n_heads, self.head_dim) - keys = keys.view(batchsize, q_len, self.n_heads, self.head_dim) - vals = vals.view(batchsize, q_len, self.n_heads, self.head_dim) + if not self.mqa: + keys = keys.view(batchsize, q_len, self.n_heads, self.head_dim) + vals = vals.view(batchsize, q_len, self.n_heads, self.head_dim) + else: + keys = keys.view(batchsize, q_len, 1, self.head_dim) + vals = vals.view(batchsize, q_len, 1, self.head_dim) + keys = keys.expand(-1, -1, self.n_heads, -1) + vals = keys.expand(-1, -1, self.n_heads, -1) past_length = 0 if past_key_value is None else past_key_value[0].shape[1] queries, keys, vals = self.pos_embed(queries, keys, vals, offset=past_length) @@ -446,6 +465,7 @@ def create_params(args): moe_capacity_factor=cfg.get("moe_capacity_factor", args.moe_capacity_factor), moe_freq=cfg.get("moe_freq", args.moe_freq), moe_top_k=cfg.get("moe_top_k", args.moe_top_k), + mqa="mqa" in args.attn_name ) diff --git a/open_lm/model_configs/open_lm_1b_geglu.json b/open_lm/model_configs/open_lm_1b_geglu.json index daa90523..fce2fb17 100644 --- a/open_lm/model_configs/open_lm_1b_geglu.json +++ b/open_lm/model_configs/open_lm_1b_geglu.json @@ -1,6 
+1,6 @@ { "hidden_dim": 2048, - "n_layers": 23, + "n_layers": 24, "n_heads": 16, "seq_len": 2048, "vocab_size": 50432, diff --git a/open_lm/model_configs/open_lm_7b_geglu.json b/open_lm/model_configs/open_lm_7b_geglu.json index 9763dccc..1f6ba189 100644 --- a/open_lm/model_configs/open_lm_7b_geglu.json +++ b/open_lm/model_configs/open_lm_7b_geglu.json @@ -1,6 +1,6 @@ { "hidden_dim": 4096, - "n_layers": 30, + "n_layers": 32, "n_heads": 32, "seq_len": 2048, "vocab_size": 50432, diff --git a/open_lm/params.py b/open_lm/params.py index 2543fd10..ee363d4d 100644 --- a/open_lm/params.py +++ b/open_lm/params.py @@ -106,7 +106,7 @@ def add_model_args(parser): "--attn-name", type=str, default="auto", - choices=["auto", "xformers_attn", "xformers_attn_variable_length", "torch_attn", "custom_attn"], + choices=["auto", "xformers_attn", "xformers_attn_variable_length", "torch_attn", "xformers_mqa", "custom_attn"], help="type of attention to use", ) parser.add_argument( From 65e7c10d93336ccdfc4d67210b04c7e9c6091053 Mon Sep 17 00:00:00 2001 From: George Smyrnis Date: Mon, 25 Mar 2024 23:02:31 -0500 Subject: [PATCH 3/7] Remove debug msg. --- open_lm/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/open_lm/model.py b/open_lm/model.py index c6a378cc..114b7fba 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -167,7 +167,6 @@ def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cach if not self.mqa: queries, keys, vals = self.in_proj(x).chunk(3, dim=-1) else: - print("HI") qkv = self.in_proj(x) queries = qkv[..., :-2 * self.head_dim] keys = qkv[..., -2 * self.head_dim : - self.head_dim] From f78ae81324795aaa948fffa1ee1939f3b73ad1a3 Mon Sep 17 00:00:00 2001 From: George Smyrnis Date: Tue, 26 Mar 2024 04:05:42 +0000 Subject: [PATCH 4/7] Update geglu configs. --- open_lm/model_configs/open_lm_1b_geglu.json | 2 +- open_lm/model_configs/open_lm_1b_geglu_mqa.json | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) create mode 100644 open_lm/model_configs/open_lm_1b_geglu_mqa.json diff --git a/open_lm/model_configs/open_lm_1b_geglu.json b/open_lm/model_configs/open_lm_1b_geglu.json index fce2fb17..dd8a18e7 100644 --- a/open_lm/model_configs/open_lm_1b_geglu.json +++ b/open_lm/model_configs/open_lm_1b_geglu.json @@ -1,5 +1,5 @@ { - "hidden_dim": 2048, + "hidden_dim": 1920, "n_layers": 24, "n_heads": 16, "seq_len": 2048, diff --git a/open_lm/model_configs/open_lm_1b_geglu_mqa.json b/open_lm/model_configs/open_lm_1b_geglu_mqa.json new file mode 100644 index 00000000..fce2fb17 --- /dev/null +++ b/open_lm/model_configs/open_lm_1b_geglu_mqa.json @@ -0,0 +1,10 @@ +{ + "hidden_dim": 2048, + "n_layers": 24, + "n_heads": 16, + "seq_len": 2048, + "vocab_size": 50432, + "post_embed_norm": false, + "weight_tying": false, + "ffn_type": "gemma_geglu" +} From c978f10ea9c32afb735d5618b4a4a25941df6f71 Mon Sep 17 00:00:00 2001 From: George Smyrnis Date: Wed, 27 Mar 2024 19:51:16 -0500 Subject: [PATCH 5/7] MQA compatible with qk norm. 
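Because the norms run on the flat projections, before any reshape, the key norm
has to be sized head_dim under MQA (one shared key head) while the query norm
keeps n_heads * head_dim. A minimal shape check, with nn.LayerNorm standing in
for the configurable norm_type (illustration only):

    import torch
    import torch.nn as nn

    n_heads, head_dim = 16, 128
    queries = torch.randn(2, 8, n_heads * head_dim)   # flat query projection
    keys = torch.randn(2, 8, head_dim)                # single key head under MQA

    q_norm = nn.LayerNorm(n_heads * head_dim)
    k_norm = nn.LayerNorm(head_dim)                   # head_dim, not n_heads * head_dim
    queries, keys = q_norm(queries), k_norm(keys)     # applied before the view/expand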
--- open_lm/model.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/open_lm/model.py b/open_lm/model.py index 114b7fba..7c18049d 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -125,12 +125,11 @@ def __init__(self, layer_id, args: Params): if not self.mqa: self.in_proj = nn.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False) else: - self.in_proj = nn.Linear(args.dim, (args.n_heads + 2)* self.head_dim, bias=False) + self.in_proj = nn.Linear(args.dim, (args.n_heads + 2) * self.head_dim, bias=False) self.out_proj = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) self.pos_embed = get_pos_embed(args) self.attn_fn = args.attn_func self.apply_qk_norm = args.apply_qk_norm - # initialize norm layers for queries and keys if needed self.q_norm = ( @@ -143,7 +142,7 @@ def __init__(self, layer_id, args: Params): ) self.k_norm = ( args.norm_type( - args.n_heads * self.head_dim, + args.n_heads * self.head_dim if not self.mqa else self.head_dim, eps=args.norm_eps, ) if self.apply_qk_norm @@ -168,9 +167,9 @@ def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cach queries, keys, vals = self.in_proj(x).chunk(3, dim=-1) else: qkv = self.in_proj(x) - queries = qkv[..., :-2 * self.head_dim] - keys = qkv[..., -2 * self.head_dim : - self.head_dim] - vals = qkv[..., - self.head_dim :] + queries = qkv[..., : -2 * self.head_dim] + keys = qkv[..., -2 * self.head_dim : -self.head_dim] + vals = qkv[..., -self.head_dim :] queries = self.q_norm(queries) keys = self.k_norm(keys) @@ -464,7 +463,7 @@ def create_params(args): moe_capacity_factor=cfg.get("moe_capacity_factor", args.moe_capacity_factor), moe_freq=cfg.get("moe_freq", args.moe_freq), moe_top_k=cfg.get("moe_top_k", args.moe_top_k), - mqa="mqa" in args.attn_name + mqa="mqa" in args.attn_name, ) From 1df785865cd66d6932b9be87ceaa73c75e185629 Mon Sep 17 00:00:00 2001 From: George Smyrnis Date: Sun, 31 Mar 2024 19:26:25 -0500 Subject: [PATCH 6/7] Add torch_attn_mqa. --- open_lm/attention.py | 2 +- open_lm/params.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/open_lm/attention.py b/open_lm/attention.py index ecde2fb1..0919be7f 100644 --- a/open_lm/attention.py +++ b/open_lm/attention.py @@ -172,7 +172,7 @@ def get_attn_func( # .view() that collapses last two dimensions fail. One thus needs to # call .contiguous() on the output tensor. [#188] return lambda *args, **kwargs: xformers_attn(*args, **kwargs).contiguous() - elif attn_name == "torch_attn": + elif attn_name == "torch_attn" or attn_name == "torch_attn_mqa": return torch_attn elif attn_name == "custom_attn": assert ( diff --git a/open_lm/params.py b/open_lm/params.py index ee363d4d..419a7151 100644 --- a/open_lm/params.py +++ b/open_lm/params.py @@ -106,7 +106,15 @@ def add_model_args(parser): "--attn-name", type=str, default="auto", - choices=["auto", "xformers_attn", "xformers_attn_variable_length", "torch_attn", "xformers_mqa", "custom_attn"], + choices=[ + "auto", + "xformers_attn", + "xformers_attn_variable_length", + "torch_attn", + "xformers_mqa", + "torch_attn_mqa", + "custom_attn", + ], help="type of attention to use", ) parser.add_argument( From da91d29cd2ebe9f168d4bf08c8ee77c300cc1f49 Mon Sep 17 00:00:00 2001 From: George Smyrnis Date: Wed, 10 Apr 2024 00:22:18 +0200 Subject: [PATCH 7/7] Update some configs. 
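For context on the width changes in the geglu configs: patch 1 pins the
gemma_geglu feed-forward width to 3 * dim, replacing the llama-style
multiple-of-256 rule that branch used before. A small arithmetic sketch
comparing the two rules at the widths appearing in these configs (illustration
only, not a claim about how 1920/3840 were chosen):

    def llama_style_hidden(dim):      # old rule removed in patch 1
        return 256 * ((int(2 * 4 * dim / 3) + 256 - 1) // 256)

    def gemma_geglu_hidden(dim):      # new rule: always 3 * dim
        return 3 * dim

    for dim in (1920, 2048, 3840, 4096):
        print(dim, llama_style_hidden(dim), gemma_geglu_hidden(dim))
    # 1920 -> 5120 / 5760    2048 -> 5632 / 6144
    # 3840 -> 10240 / 11520  4096 -> 11008 / 12288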
--- open_lm/model_configs/open_lm_7b_geglu.json | 2 +- open_lm/model_configs/open_lm_7b_geglu_mqa.json | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) create mode 100644 open_lm/model_configs/open_lm_7b_geglu_mqa.json diff --git a/open_lm/model_configs/open_lm_7b_geglu.json b/open_lm/model_configs/open_lm_7b_geglu.json index 1f6ba189..8958bede 100644 --- a/open_lm/model_configs/open_lm_7b_geglu.json +++ b/open_lm/model_configs/open_lm_7b_geglu.json @@ -1,5 +1,5 @@ { - "hidden_dim": 4096, + "hidden_dim": 3840, "n_layers": 32, "n_heads": 32, "seq_len": 2048, diff --git a/open_lm/model_configs/open_lm_7b_geglu_mqa.json b/open_lm/model_configs/open_lm_7b_geglu_mqa.json new file mode 100644 index 00000000..1f6ba189 --- /dev/null +++ b/open_lm/model_configs/open_lm_7b_geglu_mqa.json @@ -0,0 +1,10 @@ +{ + "hidden_dim": 4096, + "n_layers": 32, + "n_heads": 32, + "seq_len": 2048, + "vocab_size": 50432, + "post_embed_norm": false, + "weight_tying": false, + "ffn_type": "gemma_geglu" +}
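A closing note on what the MQA variants change at the parameter level: with a
single shared key head and value head, the fused QKV projection per layer has
(n_heads + 2) * head_dim output features instead of 3 * n_heads * head_dim.
A quick check at the 7B width, using the two in_proj definitions from patch 2
(illustration only):

    def in_proj_out_features(dim, n_heads, mqa):
        head_dim = dim // n_heads
        return (n_heads + 2) * head_dim if mqa else 3 * n_heads * head_dim

    # open_lm_7b_geglu_mqa.json: dim 4096, 32 heads
    print(in_proj_out_features(4096, 32, mqa=False))  # 12288
    print(in_proj_out_features(4096, 32, mqa=True))   # 4352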