 
 import torch
 import torch_xla.debug.profiler as xp
+from omegaconf import DictConfig
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers.activations import ACT2FN
-from transformers.modeling_utils import PreTrainedModel
-from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import CausalLMOutputWithPast
 from transformers.utils import logging
 
@@ -52,12 +51,7 @@ def forward(self, hidden_states):
 
 class LlamaRotaryEmbedding(nn.Module):
     def __init__(
-        self,
-        dim,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-        scaling_factor=1.0,
+        self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0
     ):
         super().__init__()
         self.scaling_factor = scaling_factor
@@ -161,7 +155,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 class LlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __init__(self, config: LlamaConfig, layer_idx: int | None = None):
+    def __init__(self, config: DictConfig, layer_idx: int | None = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -290,7 +284,7 @@ def forward(
 
 
 class LlamaDecoderLayer(nn.Module):
-    def __init__(self, config: LlamaConfig, layer_idx: int):
+    def __init__(self, config: DictConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
 
@@ -338,35 +332,19 @@ def forward(
         return hidden_states
 
 
-class LlamaPreTrainedModel(PreTrainedModel):
-    def _init_weights(self, module):
-        std = self.config.initializer_range
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-
-
-class LlamaModel(LlamaPreTrainedModel):
+class LlamaModel(nn.Module):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
 
     Args:
-        config: LlamaConfig
+        config: DictConfig
     """
 
-    def __init__(self, config: LlamaConfig):
-        super().__init__(config)
-        self.padding_idx = config.pad_token_id
+    def __init__(self, config: DictConfig):
+        super().__init__()
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = nn.Embedding(
-            config.vocab_size, config.hidden_size, self.padding_idx
-        )
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = nn.ModuleList(
             [
                 LlamaDecoderLayer(config, layer_idx)
@@ -375,9 +353,6 @@ def __init__(self, config: LlamaConfig):
         )
         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-        # Initialize weights and apply final processing
-        self.post_init()
-
     @xp.trace_me("LlamaModel")
     def forward(
         self,
@@ -393,11 +368,7 @@ def forward(
         # Create a causal mask without calling the current method
         seq_length = inputs_embeds.size(1)
         causal_mask = torch.triu(
-            torch.full(
-                (seq_length, seq_length),
-                float("-inf"),
-                device=inputs_embeds.device,
-            ),
+            torch.full((seq_length, seq_length), float("-inf"), device=inputs_embeds.device),
             diagonal=1,
         )
         causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)  # Add batch and head dimension
@@ -411,24 +382,34 @@ def forward(
         # decoder layers
         for decoder_layer in self.layers:
             hidden_states = decoder_layer(
-                hidden_states,
-                attention_mask=causal_mask,
-                position_ids=position_ids,
+                hidden_states, attention_mask=causal_mask, position_ids=position_ids
             )
 
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
 
-class LlamaForCausalLM(LlamaPreTrainedModel):
+class LlamaForCausalLM(nn.Module):
     def __init__(self, config):
-        super().__init__(config)
+        super().__init__()
+        self.config = config
         self.model = LlamaModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
         # Initialize weights and apply final processing
-        self.post_init()
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
 
     @xp.trace_me("LlamaForCausalLM")
     def forward(
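Usage note (not part of the diff): a minimal sketch of how the refactored model might now be instantiated, since it takes a plain OmegaConf DictConfig instead of a transformers LlamaConfig and initializes its own weights via self.apply(self._init_weights). The field names mirror the attributes referenced in this excerpt (vocab_size, hidden_size, num_hidden_layers, rms_norm_eps, initializer_range); the values are illustrative assumptions, and the attention/MLP layers elided from the diff will read further fields from the same config. The forward signature is cut off above, so no forward call is shown.

from omegaconf import OmegaConf

# Illustrative values only; a real run would load these from the training config/YAML.
config = OmegaConf.create(
    {
        "vocab_size": 32000,
        "hidden_size": 4096,
        "num_hidden_layers": 32,
        "rms_norm_eps": 1e-5,
        "initializer_range": 0.02,
        # ...plus the attention/MLP fields used by the layers not shown in this excerpt
    }
)

model = LlamaForCausalLM(config)  # weights initialized in __init__ via self.apply(self._init_weights)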