 from dataclasses import dataclass
+import socket
+from datetime import datetime
 import torch
 from torch import nn
 import torch.nn.functional as F
+from torch.autograd.profiler import record_function
 from typing import Optional


@@ -14,8 +17,9 @@ class ModelArgs:
     vocab_size: int = 1024
     ffn_hidden_dim: int = 8192
     norm_eps: float = 1e-5
-    rope_theta: float = 500000
+    rope_theta: float = 10000
     max_seq_len: int = 2048
+    max_batch_size: int = 2

     def __init__(self, **kwargs):
         if self.n_kv_heads is None:
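
Aside, not part of the diff: rope_theta is the RoPE frequency base consumed by the RotaryEmbedding module below; 10000 is the base from the original RoPE paper, while 500000 is the larger base used by long-context Llama 3 configs. A minimal sketch of the standard inverse-frequency computation, assuming this repo follows the usual formulation (the helper name is made up):

import torch

def rope_inv_freq(head_dim: int, theta: float = 10000.0) -> torch.Tensor:
    # inv_freq[i] = 1 / theta ** (2i / head_dim); a larger theta gives slower-rotating, longer-wavelength components.
    return 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
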
@@ -32,7 +36,7 @@ def forward(self, max_seq_len):
         seq = torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
         freqs = torch.outer(seq, self.inv_freq)
         freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
-        return freqs_cis[:, None, None, :]
+        return freqs_cis[None, :, None, :]


 def apply_rotary_emb(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
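
Aside, not part of the diff: the reshaped return value lets freqs_cis broadcast against the query/key tensors inside apply_rotary_emb, which (assuming the usual Llama layout) are viewed as complex tensors of shape (bsz, seqlen, n_heads, head_dim // 2). A quick shape check with made-up sizes:

import torch

# Hypothetical sizes (seqlen=16, head_dim=64), not taken from the diff.
freqs_cis = torch.polar(torch.ones(16, 32), torch.zeros(16, 32))  # stands in for (seqlen, head_dim // 2)
print(freqs_cis[None, :, None, :].shape)  # torch.Size([1, 16, 1, 32]) -- broadcasts over batch and heads
print(freqs_cis[:, None, None, :].shape)  # torch.Size([16, 1, 1, 32]) -- old layout put seqlen in the batch slot
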
@@ -43,6 +47,7 @@ def apply_rotary_emb(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:

 class FeedForward(nn.Module):
     def __init__(self, dim: int, ffn_hidden_dim: int):
+        super().__init__()
         self.w1 = nn.Linear(dim, ffn_hidden_dim, bias=False)
         self.w2 = nn.Linear(ffn_hidden_dim, dim, bias=False)
         self.w3 = nn.Linear(dim, ffn_hidden_dim, bias=False)
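
Aside, not part of the diff: the added super().__init__() call is what allows the nn.Linear layers below it to register as submodules; without it, assigning self.w1 raises "cannot assign module before Module.__init__() call". The block's forward pass sits outside this hunk; the sketch below shows the usual Llama-style SwiGLU combination, assumed rather than taken from the file:

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gated activation: silu(w1(x)) gates w3(x), then w2 projects back to the model dimension.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
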
@@ -104,5 +109,86 @@ def __init__(self, params: ModelArgs):
         self.n_layers = params.n_layers
         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
         self.layers = nn.ModuleList(TransformerBlock(params) for _ in range(params.n_layers))
-        self.norm = nn.RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
+        self.output_norm = nn.RMSNorm(params.dim, eps=params.norm_eps)
+        self.lm_head = nn.Linear(params.dim, params.vocab_size, bias=False)
+        self.rotary_embeddings = RotaryEmbedding(params.dim // params.n_heads)
+
+    def forward(self, inputs: torch.Tensor):
+        bsz, seqlen = inputs.shape
+        h = self.tok_embeddings(inputs)
+        freqs_cis = self.rotary_embeddings(seqlen)
+        mask = torch.full((seqlen, seqlen), float("-inf"), device=inputs.device)
+        mask = torch.triu(mask, diagonal=1)
+        mask = mask.type_as(h)
+        for layer in self.layers:
+            h = layer(h, freqs_cis, mask)
+        h = self.output_norm(h)
+        logits = self.lm_head(h)
+        return logits
+
+
+if __name__ == "__main__":
+    # torch.cuda.memory._record_memory_history()
+
+    def get_device():
+        if torch.cuda.is_available():
+            return torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            return torch.device("mps")
+        return torch.device("cpu")
+
+    def get_mock_batch(batch_size, seq_len, vocab_size):
+        src = torch.randint(0, vocab_size, (batch_size, seq_len), device=get_device())
+        tgt = torch.cat((src[:, 1:], torch.randint(0, vocab_size, (batch_size, 1), device=get_device())), dim=1)
+        return src, tgt
+
+    def trace_handler(prof: torch.profiler.profile):
+        # Prefix for file names.
+        host_name = socket.gethostname()
+        timestamp = datetime.now().strftime("%b_%d_%H_%M_%S")
+        file_prefix = f"{host_name}_{timestamp}"
+
+        # Construct the trace file.
+        prof.export_chrome_trace(f"{file_prefix}.json.gz")
+
+        # Construct the memory timeline file.
+        prof.export_memory_timeline(f"{file_prefix}.html", device="cuda:0")
+
+    with torch.profiler.profile(
+        activities=[
+            torch.profiler.ProfilerActivity.CPU,
+            torch.profiler.ProfilerActivity.CUDA,
+        ],
+        schedule=torch.profiler.schedule(wait=0, warmup=0, active=3, repeat=1),
+        record_shapes=True,
+        profile_memory=True,
+        with_stack=True,
+        on_trace_ready=trace_handler,
+    ) as prof:
+        args = ModelArgs()
+        num_batches = 3
+        learning_rate = 1e-5
+        model = Llama(args)
+        model.to(device=get_device())
+        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+        criterion = nn.CrossEntropyLoss()
+
+        model.train()
+        total_loss = 0
+        for i in range(num_batches):
+            prof.step()
+            src, tgt = get_mock_batch(args.max_batch_size, args.max_seq_len, args.vocab_size)
+            with record_function("## forward ##"):
+                logits = model(src)
+
+            with record_function("## backward ##"):
+                loss = criterion(logits.view(-1, args.vocab_size), tgt.view(-1))
+                loss.backward()
+            with record_function("## optimizer ##"):
+                optimizer.step()
+                optimizer.zero_grad()
+            total_loss += loss.item()
+            print(f"Step {i + 1}, Loss: {loss.item():.4f}")
+
+    prof.export_memory_timeline("llama.html", device="cuda:0")
+    # torch.cuda.memory._dump_snapshot("llama.pickle")
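
Aside, not part of the commit: once the script has run, the gzipped Chrome trace written by trace_handler can be opened in a trace viewer such as Perfetto (ui.perfetto.dev) or chrome://tracing, and the memory-timeline HTML in any browser. For a quick text summary in the terminal, the profiler object also exposes key_averages(); a minimal sketch, assuming it is called after the with block above:

# Print the ops that spent the most time on the GPU across the recorded steps.
print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=10))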