diff --git a/open_lm/attention.py b/open_lm/attention.py
index f134786c..0919be7f 100644
--- a/open_lm/attention.py
+++ b/open_lm/attention.py
@@ -164,7 +164,7 @@ def get_attn_func(
 ):
     if attn_name == "auto":
         return xformers_attn if torch.cuda.is_available() else torch_attn
-    elif attn_name == "xformers_attn":
+    elif attn_name == "xformers_attn" or attn_name == "xformers_mqa":
         return xformers_attn
     elif attn_name == "xformers_attn_variable_length":
         # Upon changing the input sequence length, xformers attention changes
@@ -172,7 +172,7 @@
         # .view() that collapses last two dimensions fail. One thus needs to
         # call .contiguous() on the output tensor. [#188]
         return lambda *args, **kwargs: xformers_attn(*args, **kwargs).contiguous()
-    elif attn_name == "torch_attn":
+    elif attn_name == "torch_attn" or attn_name == "torch_attn_mqa":
         return torch_attn
     elif attn_name == "custom_attn":
         assert (
diff --git a/open_lm/model.py b/open_lm/model.py
index 9f5c6397..7c18049d 100644
--- a/open_lm/model.py
+++ b/open_lm/model.py
@@ -99,6 +99,7 @@ class Params:
     moe_freq: int = 0
     positional_embedding_type: str = "rotary"
     ffn_type: str = "swiglu"
+    mqa: bool = False
 
 
 def get_pos_embed(args: Params):
@@ -120,7 +121,11 @@ def __init__(self, layer_id, args: Params):
         super().__init__()
         self.n_heads = args.n_heads
         self.head_dim = args.dim // args.n_heads
-        self.in_proj = nn.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False)
+        self.mqa = args.mqa
+        if not self.mqa:
+            self.in_proj = nn.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False)
+        else:
+            self.in_proj = nn.Linear(args.dim, (args.n_heads + 2) * self.head_dim, bias=False)
         self.out_proj = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
         self.pos_embed = get_pos_embed(args)
         self.attn_fn = args.attn_func
@@ -137,7 +142,7 @@ def __init__(self, layer_id, args: Params):
         )
         self.k_norm = (
             args.norm_type(
-                args.n_heads * self.head_dim,
+                args.n_heads * self.head_dim if not self.mqa else self.head_dim,
                 eps=args.norm_eps,
             )
             if self.apply_qk_norm
@@ -158,14 +163,26 @@ def reset_parameters(self):
 
     def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cache=False, attention_mask=None):
         batchsize, q_len, _ = x.shape
-        queries, keys, vals = self.in_proj(x).chunk(3, dim=-1)
+        if not self.mqa:
+            queries, keys, vals = self.in_proj(x).chunk(3, dim=-1)
+        else:
+            qkv = self.in_proj(x)
+            queries = qkv[..., : -2 * self.head_dim]
+            keys = qkv[..., -2 * self.head_dim : -self.head_dim]
+            vals = qkv[..., -self.head_dim :]
 
         queries = self.q_norm(queries)
         keys = self.k_norm(keys)
 
         queries = queries.view(batchsize, q_len, self.n_heads, self.head_dim)
-        keys = keys.view(batchsize, q_len, self.n_heads, self.head_dim)
-        vals = vals.view(batchsize, q_len, self.n_heads, self.head_dim)
+        if not self.mqa:
+            keys = keys.view(batchsize, q_len, self.n_heads, self.head_dim)
+            vals = vals.view(batchsize, q_len, self.n_heads, self.head_dim)
+        else:
+            keys = keys.view(batchsize, q_len, 1, self.head_dim)
+            vals = vals.view(batchsize, q_len, 1, self.head_dim)
+            keys = keys.expand(-1, -1, self.n_heads, -1)
+            vals = vals.expand(-1, -1, self.n_heads, -1)
 
         past_length = 0 if past_key_value is None else past_key_value[0].shape[1]
         queries, keys, vals = self.pos_embed(queries, keys, vals, offset=past_length)
@@ -243,8 +260,8 @@ def __init__(self, layer_id, args: Params):
             self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False)
             self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="none"), self._ff_w2)
"gemma_geglu": - # this follows llama / lit llama -- go to multiple of 256 - self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256) + # this is from the Griffin paper - hidden_dim is always 3 * dim + self.hidden_dim = 3 * args.dim self.feed_forward = GemmaMLP(args.dim, self.hidden_dim, layer_id) elif args.ffn_type == "moe": moe_args = MoEArgs( @@ -446,6 +463,7 @@ def create_params(args): moe_capacity_factor=cfg.get("moe_capacity_factor", args.moe_capacity_factor), moe_freq=cfg.get("moe_freq", args.moe_freq), moe_top_k=cfg.get("moe_top_k", args.moe_top_k), + mqa="mqa" in args.attn_name, ) diff --git a/open_lm/model_configs/open_lm_1b_geglu.json b/open_lm/model_configs/open_lm_1b_geglu.json new file mode 100644 index 00000000..dd8a18e7 --- /dev/null +++ b/open_lm/model_configs/open_lm_1b_geglu.json @@ -0,0 +1,10 @@ +{ + "hidden_dim": 1920, + "n_layers": 24, + "n_heads": 16, + "seq_len": 2048, + "vocab_size": 50432, + "post_embed_norm": false, + "weight_tying": false, + "ffn_type": "gemma_geglu" +} diff --git a/open_lm/model_configs/open_lm_1b_geglu_mqa.json b/open_lm/model_configs/open_lm_1b_geglu_mqa.json new file mode 100644 index 00000000..fce2fb17 --- /dev/null +++ b/open_lm/model_configs/open_lm_1b_geglu_mqa.json @@ -0,0 +1,10 @@ +{ + "hidden_dim": 2048, + "n_layers": 24, + "n_heads": 16, + "seq_len": 2048, + "vocab_size": 50432, + "post_embed_norm": false, + "weight_tying": false, + "ffn_type": "gemma_geglu" +} diff --git a/open_lm/model_configs/open_lm_7b_geglu.json b/open_lm/model_configs/open_lm_7b_geglu.json new file mode 100644 index 00000000..8958bede --- /dev/null +++ b/open_lm/model_configs/open_lm_7b_geglu.json @@ -0,0 +1,10 @@ +{ + "hidden_dim": 3840, + "n_layers": 32, + "n_heads": 32, + "seq_len": 2048, + "vocab_size": 50432, + "post_embed_norm": false, + "weight_tying": false, + "ffn_type": "gemma_geglu" +} diff --git a/open_lm/model_configs/open_lm_7b_geglu_mqa.json b/open_lm/model_configs/open_lm_7b_geglu_mqa.json new file mode 100644 index 00000000..1f6ba189 --- /dev/null +++ b/open_lm/model_configs/open_lm_7b_geglu_mqa.json @@ -0,0 +1,10 @@ +{ + "hidden_dim": 4096, + "n_layers": 32, + "n_heads": 32, + "seq_len": 2048, + "vocab_size": 50432, + "post_embed_norm": false, + "weight_tying": false, + "ffn_type": "gemma_geglu" +} diff --git a/open_lm/params.py b/open_lm/params.py index 2543fd10..419a7151 100644 --- a/open_lm/params.py +++ b/open_lm/params.py @@ -106,7 +106,15 @@ def add_model_args(parser): "--attn-name", type=str, default="auto", - choices=["auto", "xformers_attn", "xformers_attn_variable_length", "torch_attn", "custom_attn"], + choices=[ + "auto", + "xformers_attn", + "xformers_attn_variable_length", + "torch_attn", + "xformers_mqa", + "torch_attn_mqa", + "custom_attn", + ], help="type of attention to use", ) parser.add_argument(