 import torch
 import torch.distributed
 
+from torch.nn import functional as F
+
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
@@ -24,13 +26,11 @@ def __init__(
         dtype=None,
     ) -> None:
         super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
-        self.swap_dims = True
 
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.swap_dims:
-            self.weight = nn.Parameter(self.weight.T)
-            self.swap_dims = False
+    def transpose_weight(self):
+        self.weight = nn.Parameter(self.weight.T)
 
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.bias is not None:
             return torch.addmm(self.bias, input, self.weight)
         return torch.matmul(input, self.weight)
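With this change, FastLinear no longer transposes its weight lazily on the first forward pass; transpose_weight() is meant to be called once after the checkpoint is loaded, so forward() can use torch.addmm / torch.matmul directly. A minimal standalone sketch of the idea (not the repository's class; layer sizes are illustrative):

import torch
from torch import nn

# nn.Linear stores weight as (out_features, in_features); transposing it once
# lets the forward pass call addmm(bias, input, weight) with no per-call .T
linear = nn.Linear(4, 8)
x = torch.randn(2, 4)
reference = linear(x)                               # regular nn.Linear output

linear.weight = nn.Parameter(linear.weight.T)       # one-time transpose, as transpose_weight() does
fast = torch.addmm(linear.bias, x, linear.weight)   # fused bias add + matmul
assert torch.allclose(reference, fast, atol=1e-6)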
@@ -120,6 +120,10 @@ def __init__(
         self.min_id = self.tp_rank * block_size
         self.max_id = (self.tp_rank + 1) * block_size
 
+        # Additional entry that will map to zero
+        # Used for masking
+        self.null_idx = block_size
+
         super().__init__(
             block_size,
             embedding_dim,
@@ -133,15 +137,19 @@ def __init__(
             dtype=dtype,
         )
 
+    def add_null_idx(self):
+        """Additional 0 entry used for masking"""
+        self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
+
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        # `0` if input is in the correct interval, else `1`
-        input_mask = torch.logical_or(self.min_id > input, input >= self.max_id)
+        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
         # translate for [0, self.max_id - self.min_id[
-        input = input - self.min_id
-        # default all out of bounds values to `0`
-        input[input_mask] = 0
+        input = torch.where(
+            (self.min_id > input) | (input >= self.max_id),
+            self.null_idx,
+            input - self.min_id,
+        )
         out = super().forward(input)
-        out[input_mask] = 0.0
         torch.distributed.all_reduce(out, group=self.process_group)
         return out
 
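The tensor-parallel embedding now routes every token id that falls outside this rank's vocabulary shard [min_id, max_id) to an extra all-zero row (null_idx) padded onto the weight, instead of boolean-masking the input and the output; the subsequent all_reduce then sums the per-rank partial embeddings. A single-process sketch of the null-row trick (hedged: standalone tensors, not the class above):

import torch
import torch.nn.functional as F

block_size, dim = 4, 3
min_id, max_id = 4, 8                      # this shard owns vocabulary ids [4, 8)
weight = torch.randn(block_size, dim)
weight = F.pad(weight, (0, 0, 0, 1))       # append one zero row at index block_size
null_idx = block_size

input_ids = torch.tensor([2, 5, 7, 9])     # 2 and 9 live on other shards
local_ids = torch.where(
    (input_ids < min_id) | (input_ids >= max_id),
    torch.tensor(null_idx),
    input_ids - min_id,
)
out = F.embedding(local_ids, weight)       # foreign ids come out as zero vectors

In the distributed version each rank contributes zeros for the ids it does not own, so torch.distributed.all_reduce reconstructs the full embedding.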
@@ -214,11 +222,9 @@ def __init__(
             hidden_size,
             process_group=process_group,
         )
-        self.swap_dims = True
 
-    # TODO: remove and swap dims when loading weights
-    def _swap_dims(self):
-        """Swap dims for the first inference to avoid an additional permute"""
+    def shuffle_qkv_dims(self):
+        """Swap dims to avoid an additional permute"""
         self.query_key_value.weight = torch.nn.Parameter(
             self.query_key_value.weight.view(
                 self.num_heads, 3, self.head_size, self.hidden_size
@@ -231,7 +237,6 @@ def _swap_dims(self):
             .permute(1, 0, 2)
             .reshape(-1)
         )
-        self.swap_dims = False
 
     def forward(
         self,
@@ -244,9 +249,6 @@ def forward(
         layer_past_present_indices,
         cu_seqlens_q,
     ):
-        if self.swap_dims:
-            self._swap_dims()
-
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
         qkv_rot = self.rotary_emb(qkv, cos, sin)
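shuffle_qkv_dims() reorders the fused query/key/value projection once at load time: the checkpoint interleaves the weight per head as (num_heads, 3, head_size, hidden_size), and regrouping it by q/k/v lets forward() view the projection output directly as (-1, 3, num_heads, head_size) with no runtime permute. A shape-only sketch with illustrative dimensions (not the repository's module):

import torch

num_heads, head_size, hidden_size = 2, 4, 8
w = torch.randn(num_heads * 3 * head_size, hidden_size)   # checkpoint layout, interleaved per head

w_shuffled = (
    w.view(num_heads, 3, head_size, hidden_size)
    .permute(1, 0, 2, 3)                                   # group q, k, v contiguously
    .reshape(-1, hidden_size)
)

x = torch.randn(5, hidden_size)                            # 5 token embeddings
qkv = (x @ w_shuffled.T).view(-1, 3, num_heads, head_size)
q, k, v = qkv.unbind(dim=1)                                # each (5, num_heads, head_size)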
@@ -329,7 +331,6 @@ def __init__(self, act, hidden_size, intermediate_size, process_group=None):
             hidden_size,
             process_group=process_group,
         )
-        self.heuristic = "auto"
         self.process_group = process_group
 
     def forward(self, hidden_states):
@@ -531,6 +532,25 @@ def __init__(self, config, process_group=None):
         self.head_size = self.layers[0].attention.head_size
         self.num_heads = self.layers[0].attention.num_heads
 
+    def post_load_weights(self):
+        if isinstance(self.embed_in, TensorParallelEmbedding):
+            self.embed_in.add_null_idx()
+        for layer in self.layers:
+            layer: FlashNeoXLayer
+            layer.attention.shuffle_qkv_dims()
+            layer.attention.query_key_value.transpose_weight()
+            layer.attention.dense.transpose_weight()
+            layer.mlp.dense_h_to_4h.transpose_weight()
+            layer.mlp.dense_4h_to_h.transpose_weight()
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        model = super(FlashGPTNeoXModel, cls).from_pretrained(
+            pretrained_model_name_or_path, *model_args, **kwargs
+        )
+        model.post_load_weights()
+        return model
+
     def forward(
         self,
         input_ids,
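post_load_weights() collects the one-time fix-ups (null-row padding for the tensor-parallel embedding, QKV shuffling, and the linear-weight transposes) and runs them right after from_pretrained() materializes the checkpoint, which removes the first-call flags from the forward path. A minimal sketch of the same pattern on a toy module (the class and the from_state_dict helper are hypothetical, not part of transformers):

import torch
from torch import nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4, bias=False)

    def post_load_weights(self):
        # one-time rewrite after loading, e.g. a transpose the forward pass relies on
        self.proj.weight = nn.Parameter(self.proj.weight.T)

    @classmethod
    def from_state_dict(cls, state_dict):
        model = cls()
        model.load_state_dict(state_dict)
        model.post_load_weights()     # every load path ends with the fix-up
        return model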
@@ -627,6 +647,18 @@ def __init__(self, config):
             config.hidden_size, config.vocab_size, bias=False
         )
 
+    def post_load_weights(self):
+        self.gpt_neox.post_load_weights()
+        self.embed_out.transpose_weight()
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
+            pretrained_model_name_or_path, *model_args, **kwargs
+        )
+        model.post_load_weights()
+        return model
+
     def forward(
         self,
         input_ids,
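Usage stays unchanged for callers: loading through the overridden from_pretrained() runs post_load_weights(), which cascades into gpt_neox and transposes embed_out, so nothing extra is needed before the first forward pass. A hedged example (model id and dtype are illustrative):

import torch

model = FlashGPTNeoXForCausalLM.from_pretrained(
    "EleutherAI/gpt-neox-20b", torch_dtype=torch.float16
)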